author | Eric Cox <eric.cox@mongodb.com> | 2022-06-22 02:26:08 +0000
---|---|---
committer | Eric Cox <eric.cox@mongodb.com> | 2022-06-22 02:26:08 +0000
commit | e24af2270e6bc1d435845c8fdd02a1eb24155da2 (patch) |
tree | 630fdee303fd190f8a3e75779910ba9e45c90b55 /src |
parent | ecf315a92efb6283a4c8f3fd079dd81398563fa6 (diff) |
parent | 816918e7ebb03cd74fdfb935590d14fc9e8d0210 (diff) |
download | mongo-e24af2270e6bc1d435845c8fdd02a1eb24155da2.tar.gz |
Merge branch 'v6.0' of github.com:10gen/mongo into v6.0
Diffstat (limited to 'src')
54 files changed, 1944 insertions, 308 deletions
diff --git a/src/mongo/base/error_codes.yml b/src/mongo/base/error_codes.yml index 523c52065cf..0b3150042bd 100644 --- a/src/mongo/base/error_codes.yml +++ b/src/mongo/base/error_codes.yml @@ -490,6 +490,7 @@ error_codes: - {code: 374, name: TransactionAPIMustRetryTransaction, categories: [InternalOnly]} - {code: 375, name: TransactionAPIMustRetryCommit, categories: [InternalOnly]} + - {code: 377, name: FLEMaxTagLimitExceeded } # Error codes 4000-8999 are reserved. diff --git a/src/mongo/crypto/encryption_fields.idl b/src/mongo/crypto/encryption_fields.idl index 903a1d4f415..f8205b1f156 100644 --- a/src/mongo/crypto/encryption_fields.idl +++ b/src/mongo/crypto/encryption_fields.idl @@ -58,7 +58,7 @@ structs: contention: description: "Contention factor for field, 0 means it has extremely high set number of distinct values" type: exactInt64 - default: 0 + default: 4 unstable: true validator: { gte: 0 } diff --git a/src/mongo/crypto/fle_crypto.cpp b/src/mongo/crypto/fle_crypto.cpp index 38800792351..f55db25f970 100644 --- a/src/mongo/crypto/fle_crypto.cpp +++ b/src/mongo/crypto/fle_crypto.cpp @@ -53,6 +53,7 @@ #include "mongo/base/error_codes.h" #include "mongo/base/status.h" #include "mongo/bson/bson_depth.h" +#include "mongo/bson/bsonmisc.h" #include "mongo/bson/bsonobj.h" #include "mongo/bson/bsonobjbuilder.h" #include "mongo/bson/bsontypes.h" @@ -153,6 +154,8 @@ PrfBlock blockToArray(const SHA256Block& block) { return data; } +} // namespace + PrfBlock PrfBlockfromCDR(ConstDataRange block) { uassert(6373501, "Invalid prf length", block.length() == sizeof(PrfBlock)); @@ -161,6 +164,7 @@ PrfBlock PrfBlockfromCDR(ConstDataRange block) { return ret; } +namespace { ConstDataRange hmacKey(const KeyMaterial& keyMaterial) { static_assert(kHmacKeyOffset + crypto::sym256KeySize <= crypto::kFieldLevelEncryptionKeySize); invariant(crypto::kFieldLevelEncryptionKeySize == keyMaterial->size()); @@ -212,15 +216,18 @@ ConstDataRange binDataToCDR(const BSONElement element) { return ConstDataRange(data, data + len); } -ConstDataRange binDataToCDR(const Value& value) { - uassert(6334103, "Expected binData Value type", value.getType() == BinData); - - auto binData = value.getBinData(); +ConstDataRange binDataToCDR(const BSONBinData binData) { int len = binData.length; const char* data = static_cast<const char*>(binData.data); return ConstDataRange(data, data + len); } +ConstDataRange binDataToCDR(const Value& value) { + uassert(6334103, "Expected binData Value type", value.getType() == BinData); + + return binDataToCDR(value.getBinData()); +} + template <typename T> void toBinData(StringData field, T t, BSONObjBuilder* builder) { BSONObj obj = t.toBSON(); @@ -292,7 +299,7 @@ void toEncryptedBinData(StringData field, std::pair<EncryptedBinDataType, ConstDataRange> fromEncryptedBinData(BSONElement element) { uassert( - 6373502, "Expected binData with subtype Encrypt", element.isBinData(BinDataType::Encrypt)); + 6672414, "Expected binData with subtype Encrypt", element.isBinData(BinDataType::Encrypt)); return fromEncryptedConstDataRange(binDataToCDR(element)); } @@ -965,7 +972,6 @@ void parseAndVerifyInsertUpdatePayload(std::vector<EDCServerPayloadInfo>* pField void collectEDCServerInfo(std::vector<EDCServerPayloadInfo>* pFields, ConstDataRange cdr, - StringData fieldPath) { // TODO - validate field is actually indexed in the schema? 
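The helpers touched in this hunk and the next (PrfBlockfromCDR, fromEncryptedBinData, and the new toEncryptedVector/toBSONBinData) all frame an encrypted payload as a one-byte EncryptedBinDataType tag followed by the raw bytes. A minimal standalone sketch of that framing and its inverse — assuming PrfBlock is a fixed 32-byte std::array and taking the subtype numbers from the fle_field_schema.idl hunk further down; the real MongoDB types and uassert codes are not used here:

```cpp
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>

// Assumption: PrfBlock is treated as a fixed 32-byte HMAC output block.
constexpr std::size_t kPrfBlockSize = 32;
using PrfBlock = std::array<std::uint8_t, kPrfBlockSize>;

// Subtype values mirror the fle_field_schema.idl enum shown later in this diff.
enum class EncryptedBinDataType : std::uint8_t {
    kFLE2EqualityIndexedValue = 7,
    kFLE2TransientRaw = 8,  // new transient, never-persisted subtype added by this change
};

// Layout: [ 1-byte subtype tag | block bytes ], as built by toEncryptedVector().
std::vector<std::uint8_t> toEncryptedVector(EncryptedBinDataType dt, const PrfBlock& block) {
    std::vector<std::uint8_t> buf(block.size() + 1);
    buf[0] = static_cast<std::uint8_t>(dt);
    std::copy(block.begin(), block.end(), buf.begin() + 1);
    return buf;
}

// Inverse: peel the tag off and recover the fixed-size block, mirroring the
// "Invalid prf length" check in PrfBlockfromCDR.
std::pair<EncryptedBinDataType, PrfBlock> fromEncryptedVector(
    const std::vector<std::uint8_t>& buf) {
    if (buf.size() != kPrfBlockSize + 1) {
        throw std::invalid_argument("Invalid prf length");
    }
    PrfBlock block{};
    std::copy(buf.begin() + 1, buf.end(), block.begin());
    return {static_cast<EncryptedBinDataType>(buf[0]), block};
}

int main() {
    PrfBlock block{};
    block[0] = 0xAB;
    auto buf = toEncryptedVector(EncryptedBinDataType::kFLE2TransientRaw, block);
    auto [type, parsed] = fromEncryptedVector(buf);
    return (type == EncryptedBinDataType::kFLE2TransientRaw && parsed == block) ? 0 : 1;
}
```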
@@ -1163,6 +1169,28 @@ uint64_t generateRandomContention(uint64_t cm) { } // namespace +std::pair<EncryptedBinDataType, ConstDataRange> fromEncryptedBinData(const Value& value) { + uassert(6672416, "Expected binData with subtype Encrypt", value.getType() == BinData); + + auto binData = value.getBinData(); + + uassert(6672415, "Expected binData with subtype Encrypt", binData.type == BinDataType::Encrypt); + + return fromEncryptedConstDataRange(binDataToCDR(binData)); +} + +BSONBinData toBSONBinData(const std::vector<uint8_t>& buf) { + return BSONBinData(buf.data(), buf.size(), Encrypt); +} + +std::vector<uint8_t> toEncryptedVector(EncryptedBinDataType dt, const PrfBlock& block) { + std::vector<uint8_t> buf(block.size() + 1); + buf[0] = static_cast<uint8_t>(dt); + + std::copy(block.data(), block.data() + block.size(), buf.data() + 1); + + return buf; +} CollectionsLevel1Token FLELevel1TokenGenerator::generateCollectionsLevel1Token( FLEIndexKey indexKey) { @@ -1364,6 +1392,8 @@ std::pair<BSONType, std::vector<uint8_t>> FLEClientCrypto::decrypt(ConstDataRang return {EOO, vectorFromCDR(pair.second)}; } else if (pair.first == EncryptedBinDataType::kFLE2InsertUpdatePayload) { return {EOO, vectorFromCDR(pair.second)}; + } else if (pair.first == EncryptedBinDataType::kFLE2TransientRaw) { + return {EOO, vectorFromCDR(pair.second)}; } else { uasserted(6373507, "Not supported"); } @@ -1720,6 +1750,8 @@ FLE2FindEqualityPayload FLEClientCrypto::serializeFindPayload(FLEIndexKeyAndId i auto value = ConstDataRange(element.value(), element.value() + element.valuesize()); auto collectionToken = FLELevel1TokenGenerator::generateCollectionsLevel1Token(indexKey.key); + auto serverToken = + FLELevel1TokenGenerator::generateServerDataEncryptionLevel1Token(indexKey.key); auto edcToken = FLECollectionTokenGenerator::generateEDCToken(collectionToken); auto escToken = FLECollectionTokenGenerator::generateESCToken(collectionToken); @@ -1738,6 +1770,7 @@ FLE2FindEqualityPayload FLEClientCrypto::serializeFindPayload(FLEIndexKeyAndId i payload.setEscDerivedToken(escDatakey.toCDR()); payload.setEccDerivedToken(eccDatakey.toCDR()); payload.setMaxCounter(maxContentionFactor); + payload.setServerEncryptionToken(serverToken.toCDR()); return payload; } @@ -2019,7 +2052,8 @@ ESCDerivedFromDataTokenAndContentionFactorToken EDCServerPayloadInfo::getESCToke } void EDCServerCollection::validateEncryptedFieldInfo(BSONObj& obj, - const EncryptedFieldConfig& efc) { + const EncryptedFieldConfig& efc, + bool bypassDocumentValidation) { stdx::unordered_set<std::string> indexedFields; for (auto f : efc.getFields()) { if (f.getQueries().has_value()) { @@ -2036,6 +2070,11 @@ void EDCServerCollection::validateEncryptedFieldInfo(BSONObj& obj, indexedFields.contains(fieldPath.toString())); } }); + + // We should ensure that the user is not manually modifying the safe content array. 
+ uassert(6666200, + str::stream() << "Cannot modify " << kSafeContent << " field in document.", + !obj.hasField(kSafeContent) || bypassDocumentValidation); } @@ -2076,6 +2115,44 @@ PrfBlock EDCServerCollection::generateTag(const FLE2IndexedEqualityEncryptedValu return generateTag(edcTwiceDerived, indexedValue.count); } + +StatusWith<FLE2IndexedEqualityEncryptedValue> EDCServerCollection::decryptAndParse( + ServerDataEncryptionLevel1Token token, ConstDataRange serializedServerValue) { + auto pair = fromEncryptedConstDataRange(serializedServerValue); + uassert(6672412, + "Wrong encrypted field type", + pair.first == EncryptedBinDataType::kFLE2EqualityIndexedValue); + + return FLE2IndexedEqualityEncryptedValue::decryptAndParse(token, pair.second); +} + +StatusWith<FLE2IndexedEqualityEncryptedValue> EDCServerCollection::decryptAndParse( + ConstDataRange token, ConstDataRange serializedServerValue) { + auto serverToken = FLETokenFromCDR<FLETokenType::ServerDataEncryptionLevel1Token>(token); + + return FLE2IndexedEqualityEncryptedValue::decryptAndParse(serverToken, serializedServerValue); +} + +std::vector<EDCDerivedFromDataTokenAndContentionFactorToken> EDCServerCollection::generateEDCTokens( + EDCDerivedFromDataToken token, uint64_t maxContentionFactor) { + std::vector<EDCDerivedFromDataTokenAndContentionFactorToken> tokens; + tokens.reserve(maxContentionFactor); + + for (uint64_t i = 0; i <= maxContentionFactor; ++i) { + tokens.push_back(FLEDerivedFromDataTokenAndContentionFactorTokenGenerator:: + generateEDCDerivedFromDataTokenAndContentionFactorToken(token, i)); + } + + return tokens; +} + +std::vector<EDCDerivedFromDataTokenAndContentionFactorToken> EDCServerCollection::generateEDCTokens( + ConstDataRange rawToken, uint64_t maxContentionFactor) { + auto token = FLETokenFromCDR<FLETokenType::EDCDerivedFromDataToken>(rawToken); + + return generateEDCTokens(token, maxContentionFactor); +} + BSONObj EDCServerCollection::finalizeForInsert( const BSONObj& doc, const std::vector<EDCServerPayloadInfo>& serverPayload) { std::vector<TagInfo> tags; @@ -2305,6 +2382,7 @@ EncryptedFieldConfig EncryptionInformationHelpers::getAndValidateSchema( return efc; } + std::pair<EncryptedBinDataType, ConstDataRange> fromEncryptedConstDataRange(ConstDataRange cdr) { ConstDataRangeCursor cdrc(cdr); @@ -2377,6 +2455,12 @@ ParsedFindPayload::ParsedFindPayload(ConstDataRange cdr) { escToken = FLETokenFromCDR<FLETokenType::ESCDerivedFromDataToken>(payload.getEscDerivedToken()); eccToken = FLETokenFromCDR<FLETokenType::ECCDerivedFromDataToken>(payload.getEccDerivedToken()); edcToken = FLETokenFromCDR<FLETokenType::EDCDerivedFromDataToken>(payload.getEdcDerivedToken()); + + if (payload.getServerEncryptionToken().has_value()) { + serverToken = FLETokenFromCDR<FLETokenType::ServerDataEncryptionLevel1Token>( + payload.getServerEncryptionToken().value()); + } + maxCounter = payload.getMaxCounter(); } diff --git a/src/mongo/crypto/fle_crypto.h b/src/mongo/crypto/fle_crypto.h index 5d8285f5790..5feac8ca2d3 100644 --- a/src/mongo/crypto/fle_crypto.h +++ b/src/mongo/crypto/fle_crypto.h @@ -41,6 +41,7 @@ #include "mongo/base/status_with.h" #include "mongo/base/string_data.h" #include "mongo/bson/bsonelement.h" +#include "mongo/bson/bsonmisc.h" #include "mongo/bson/bsonobj.h" #include "mongo/bson/bsontypes.h" #include "mongo/crypto/aead_encryption.h" @@ -1009,13 +1010,21 @@ public: /** * Validate that payload is compatible with schema */ - static void validateEncryptedFieldInfo(BSONObj& obj, const EncryptedFieldConfig& efc); + 
static void validateEncryptedFieldInfo(BSONObj& obj, + const EncryptedFieldConfig& efc, + bool bypassDocumentValidation); /** * Get information about all FLE2InsertUpdatePayload payloads */ static std::vector<EDCServerPayloadInfo> getEncryptedFieldInfo(BSONObj& obj); + static StatusWith<FLE2IndexedEqualityEncryptedValue> decryptAndParse( + ServerDataEncryptionLevel1Token token, ConstDataRange serializedServerValue); + + static StatusWith<FLE2IndexedEqualityEncryptedValue> decryptAndParse( + ConstDataRange token, ConstDataRange serializedServerValue); + /** * Generate a search tag * @@ -1026,6 +1035,14 @@ public: static PrfBlock generateTag(const FLE2IndexedEqualityEncryptedValue& indexedValue); /** + * Generate all the EDC tokens + */ + static std::vector<EDCDerivedFromDataTokenAndContentionFactorToken> generateEDCTokens( + EDCDerivedFromDataToken token, uint64_t maxContentionFactor); + static std::vector<EDCDerivedFromDataTokenAndContentionFactorToken> generateEDCTokens( + ConstDataRange rawToken, uint64_t maxContentionFactor); + + /** * Consumes a payload from a MongoDB client for insert. * * Converts FLE2InsertUpdatePayload to a final insert payload and updates __safeContent__ with @@ -1163,6 +1180,7 @@ struct ParsedFindPayload { ESCDerivedFromDataToken escToken; ECCDerivedFromDataToken eccToken; EDCDerivedFromDataToken edcToken; + boost::optional<ServerDataEncryptionLevel1Token> serverToken; boost::optional<std::int64_t> maxCounter; explicit ParsedFindPayload(BSONElement fleFindPayload); @@ -1170,4 +1188,15 @@ struct ParsedFindPayload { explicit ParsedFindPayload(ConstDataRange cdr); }; +/** + * Utility functions manipulating buffers + */ +PrfBlock PrfBlockfromCDR(ConstDataRange block); + +std::vector<uint8_t> toEncryptedVector(EncryptedBinDataType dt, const PrfBlock& block); + +BSONBinData toBSONBinData(const std::vector<uint8_t>& buf); + +std::pair<EncryptedBinDataType, ConstDataRange> fromEncryptedBinData(const Value& value); + } // namespace mongo diff --git a/src/mongo/crypto/fle_crypto_test.cpp b/src/mongo/crypto/fle_crypto_test.cpp index 75c1976c097..4c4355ebb9f 100644 --- a/src/mongo/crypto/fle_crypto_test.cpp +++ b/src/mongo/crypto/fle_crypto_test.cpp @@ -33,6 +33,7 @@ #include "mongo/crypto/fle_crypto.h" #include <algorithm> +#include <cstdint> #include <iostream> #include <limits> #include <stack> @@ -696,7 +697,8 @@ std::vector<char> generatePlaceholder( BSONElement value, Operation operation, mongo::Fle2AlgorithmInt algorithm = mongo::Fle2AlgorithmInt::kEquality, - boost::optional<UUID> key = boost::none) { + boost::optional<UUID> key = boost::none, + uint64_t contention = 0) { FLE2EncryptionPlaceholder ep; if (operation == Operation::kFind) { @@ -709,7 +711,7 @@ std::vector<char> generatePlaceholder( ep.setUserKeyId(userKeyId); ep.setIndexKeyId(key.value_or(indexKeyId)); ep.setValue(value); - ep.setMaxContentionCounter(0); + ep.setMaxContentionCounter(contention); BSONObj obj = ep.toBSON(); @@ -726,7 +728,7 @@ BSONObj encryptDocument(BSONObj obj, auto result = FLEClientCrypto::transformPlaceholders(obj, keyVault); if (nullptr != efc) { - EDCServerCollection::validateEncryptedFieldInfo(result, *efc); + EDCServerCollection::validateEncryptedFieldInfo(result, *efc, false); } // Start Server Side @@ -832,6 +834,41 @@ void roundTripMultiencrypted(BSONObj doc1, assertPayload(finalDoc["encrypted2"], operation2); } +// Used to generate the test data for the ExpressionFLETest in expression_test.cpp +TEST(FLE_EDC, PrintTest) { + auto doc = BSON("value" << 1); + auto element = 
doc.firstElement(); + + TestKeyVault keyVault; + + auto inputDoc = BSON("plainText" + << "sample" + << "encrypted" << element); + + { + auto buf = generatePlaceholder(element, Operation::kInsert, Fle2AlgorithmInt::kEquality); + BSONObjBuilder builder; + builder.append("plainText", "sample"); + builder.appendBinData("encrypted", buf.size(), BinDataType::Encrypt, buf.data()); + + auto finalDoc = encryptDocument(builder.obj(), &keyVault); + + std::cout << finalDoc.jsonString() << std::endl; + } + + { + auto buf = generatePlaceholder( + element, Operation::kInsert, Fle2AlgorithmInt::kEquality, boost::none, 50); + BSONObjBuilder builder; + builder.append("plainText", "sample"); + builder.appendBinData("encrypted", buf.size(), BinDataType::Encrypt, buf.data()); + + auto finalDoc = encryptDocument(builder.obj(), &keyVault); + + std::cout << finalDoc.jsonString() << std::endl; + } +} + TEST(FLE_EDC, Allowed_Types) { const std::vector<std::pair<BSONObj, BSONType>> universallyAllowedObjects{ {BSON("sample" @@ -1928,4 +1965,25 @@ TEST(CompactionHelpersTest, countDeletedTest) { ASSERT_EQ(CompactionHelpers::countDeleted(input), 20); } +TEST(EDCServerCollectionTest, GenerateEDCTokens) { + + auto doc = BSON("sample" << 123456); + auto element = doc.firstElement(); + + auto value = ConstDataRange(element.value(), element.value() + element.valuesize()); + + auto collectionToken = FLELevel1TokenGenerator::generateCollectionsLevel1Token(getIndexKey()); + auto edcToken = FLECollectionTokenGenerator::generateEDCToken(collectionToken); + + EDCDerivedFromDataToken edcDatakey = + FLEDerivedFromDataTokenGenerator::generateEDCDerivedFromDataToken(edcToken, value); + + + ASSERT_EQ(EDCServerCollection::generateEDCTokens(edcDatakey, 0).size(), 1); + ASSERT_EQ(EDCServerCollection::generateEDCTokens(edcDatakey, 1).size(), 2); + ASSERT_EQ(EDCServerCollection::generateEDCTokens(edcDatakey, 2).size(), 3); + ASSERT_EQ(EDCServerCollection::generateEDCTokens(edcDatakey, 3).size(), 4); +} + + } // namespace mongo diff --git a/src/mongo/crypto/fle_field_schema.idl b/src/mongo/crypto/fle_field_schema.idl index 030fa36ef3f..d9e2b54b890 100644 --- a/src/mongo/crypto/fle_field_schema.idl +++ b/src/mongo/crypto/fle_field_schema.idl @@ -51,6 +51,10 @@ enums: kFLE2UnindexedEncryptedValue : 6 # see FLE2IndexedEqualityEncryptedValue kFLE2EqualityIndexedValue : 7 + # Transient encrypted data in query rewrites, not persisted + # same as BinDataGeneral but redacted + kFLE2TransientRaw : 8 + FleVersion: description: "The version / type of field-level encryption in use." 
type: int diff --git a/src/mongo/crypto/fle_tags.cpp b/src/mongo/crypto/fle_tags.cpp index a0de37b2f42..4737ff13144 100644 --- a/src/mongo/crypto/fle_tags.cpp +++ b/src/mongo/crypto/fle_tags.cpp @@ -56,8 +56,11 @@ void verifyTagsWillFit(size_t tagCount, size_t memoryLimit) { constexpr size_t largestElementSize = arrayElementSize(std::numeric_limits<size_t>::digits10); constexpr size_t ridiculousNumberOfTags = std::numeric_limits<size_t>::max() / largestElementSize; - uassert(6653300, "Encrypted rewrite too many tags", tagCount < ridiculousNumberOfTags); - uassert(6401800, + + uassert(ErrorCodes::FLEMaxTagLimitExceeded, + "Encrypted rewrite too many tags", + tagCount < ridiculousNumberOfTags); + uassert(ErrorCodes::FLEMaxTagLimitExceeded, "Encrypted rewrite memory limit exceeded", sizeArrayElementsMemory(tagCount) <= memoryLimit); } diff --git a/src/mongo/db/active_index_builds.cpp b/src/mongo/db/active_index_builds.cpp index 017ca49ecc8..d6a7f5afabf 100644 --- a/src/mongo/db/active_index_builds.cpp +++ b/src/mongo/db/active_index_builds.cpp @@ -28,6 +28,8 @@ */ #define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kStorage +#include <fmt/format.h> + #include "mongo/db/active_index_builds.h" #include "mongo/db/catalog/index_builds_manager.h" #include "mongo/logv2/log.h" @@ -63,10 +65,14 @@ void ActiveIndexBuilds::waitForAllIndexBuildsToStopForShutdown(OperationContext* void ActiveIndexBuilds::assertNoIndexBuildInProgress() const { stdx::unique_lock<Latch> lk(_mutex); - uassert(ErrorCodes::BackgroundOperationInProgressForDatabase, - str::stream() << "cannot perform operation: there are currently " - << _allIndexBuilds.size() << " index builds running.", - _allIndexBuilds.size() == 0); + if (!_allIndexBuilds.empty()) { + auto firstIndexBuild = _allIndexBuilds.cbegin()->second; + uasserted(ErrorCodes::BackgroundOperationInProgressForDatabase, + fmt::format("cannot perform operation: there are currently {} index builds " + "running. Found index build: {}", + _allIndexBuilds.size(), + firstIndexBuild->buildUUID.toString())); + } } void ActiveIndexBuilds::waitUntilAnIndexBuildFinishes(OperationContext* opCtx) { diff --git a/src/mongo/db/catalog/coll_mod.cpp b/src/mongo/db/catalog/coll_mod.cpp index ca5f29ae55e..8af0ba8efc9 100644 --- a/src/mongo/db/catalog/coll_mod.cpp +++ b/src/mongo/db/catalog/coll_mod.cpp @@ -175,13 +175,11 @@ StatusWith<std::pair<ParsedCollModRequest, BSONObj>> parseCollModRequest(Operati } if (auto& cappedSize = cmr.getCappedSize()) { - static constexpr long long minCappedSize = 4096; auto swCappedSize = CollectionOptions::checkAndAdjustCappedSize(*cappedSize); if (!swCappedSize.isOK()) { return swCappedSize.getStatus(); } - parsed.cappedSize = - (swCappedSize.getValue() < minCappedSize) ? 
minCappedSize : swCappedSize.getValue(); + parsed.cappedSize = swCappedSize.getValue(); oplogEntryBuilder.append(CollMod::kCappedSizeFieldName, *cappedSize); } if (auto& cappedMax = cmr.getCappedMax()) { diff --git a/src/mongo/db/catalog/collection_impl.cpp b/src/mongo/db/catalog/collection_impl.cpp index 29e343e44bc..b79c8c78914 100644 --- a/src/mongo/db/catalog/collection_impl.cpp +++ b/src/mongo/db/catalog/collection_impl.cpp @@ -38,6 +38,7 @@ #include "mongo/bson/ordering.h" #include "mongo/bson/simple_bsonelement_comparator.h" #include "mongo/bson/simple_bsonobj_comparator.h" +#include "mongo/crypto/fle_crypto.h" #include "mongo/db/auth/security_token.h" #include "mongo/db/catalog/collection_catalog.h" #include "mongo/db/catalog/collection_options.h" @@ -816,7 +817,6 @@ Status CollectionImpl::insertDocumentsForOplog(OperationContext* opCtx, return status; } - Status CollectionImpl::insertDocuments(OperationContext* opCtx, const std::vector<InsertStatement>::const_iterator begin, const std::vector<InsertStatement>::const_iterator end, @@ -840,8 +840,20 @@ Status CollectionImpl::insertDocuments(OperationContext* opCtx, } auto status = _checkValidationAndParseResult(opCtx, it->doc); - if (!status.isOK()) + if (!status.isOK()) { return status; + } + + auto& validationSettings = DocumentValidationSettings::get(opCtx); + + if (getCollectionOptions().encryptedFieldConfig && + !validationSettings.isSchemaValidationDisabled() && + !validationSettings.isSafeContentValidationDisabled() && + it->doc.hasField(kSafeContent)) { + return Status(ErrorCodes::BadValue, + str::stream() + << "Cannot insert a document with field name " << kSafeContent); + } } const SnapshotId sid = opCtx->recoveryUnit()->getSnapshotId(); @@ -1342,6 +1354,17 @@ void CollectionImpl::deleteDocument(OperationContext* opCtx, } } +bool compareSafeContentElem(const BSONObj& oldDoc, const BSONObj& newDoc) { + if (newDoc.hasField(kSafeContent) != oldDoc.hasField(kSafeContent)) { + return false; + } + if (!newDoc.hasField(kSafeContent)) { + return true; + } + + return newDoc.getField(kSafeContent).binaryEqual(oldDoc.getField(kSafeContent)); +} + RecordId CollectionImpl::updateDocument(OperationContext* opCtx, RecordId oldLocation, const Snapshotted<BSONObj>& oldDoc, @@ -1366,6 +1389,17 @@ RecordId CollectionImpl::updateDocument(OperationContext* opCtx, } } + auto& validationSettings = DocumentValidationSettings::get(opCtx); + if (getCollectionOptions().encryptedFieldConfig && + !validationSettings.isSchemaValidationDisabled() && + !validationSettings.isSafeContentValidationDisabled()) { + + uassert(ErrorCodes::BadValue, + str::stream() << "New document and old document both need to have " << kSafeContent + << " field.", + compareSafeContentElem(oldDoc.value(), newDoc)); + } + dassert(opCtx->lockState()->isCollectionLockedForMode(ns(), MODE_IX)); invariant(oldDoc.snapshotId() == opCtx->recoveryUnit()->getSnapshotId()); invariant(newDoc.isOwned()); diff --git a/src/mongo/db/catalog/document_validation.h b/src/mongo/db/catalog/document_validation.h index 875f255c565..47db304d79d 100644 --- a/src/mongo/db/catalog/document_validation.h +++ b/src/mongo/db/catalog/document_validation.h @@ -52,7 +52,7 @@ class DocumentValidationSettings { public: enum flag : std::uint8_t { /* - * Enables document validation (both schema and internal). + * Enables document validation (schema, internal, and safeContent). */ kEnableValidation = 0x00, /* @@ -67,6 +67,12 @@ public: * doesn't comply with internal validation rules. 
*/ kDisableInternalValidation = 0x02, + /* + * If set, modifications to the safeContent array are allowed. This flag is only + * enabled when bypass document validation is enabled or if crudProcessed is true + * in the query. + */ + kDisableSafeContentValidation = 0x04, }; using Flags = std::uint8_t; @@ -92,6 +98,10 @@ public: return _flags & kDisableInternalValidation; } + bool isSafeContentValidationDisabled() const { + return _flags & kDisableSafeContentValidation; + } + bool isDocumentValidationEnabled() const { return _flags == kEnableValidation; } @@ -134,11 +144,29 @@ class DisableDocumentSchemaValidationIfTrue { public: DisableDocumentSchemaValidationIfTrue(OperationContext* opCtx, bool shouldDisableSchemaValidation) { - if (shouldDisableSchemaValidation) - _documentSchemaValidationDisabler.emplace(opCtx); + if (shouldDisableSchemaValidation) { + _documentSchemaValidationDisabler.emplace( + opCtx, DocumentValidationSettings::kDisableSchemaValidation); + } + } + +private: + boost::optional<DisableDocumentValidation> _documentSchemaValidationDisabler; +}; + +class DisableSafeContentValidationIfTrue { +public: + DisableSafeContentValidationIfTrue(OperationContext* opCtx, + bool shouldDisableSchemaValidation, + bool encryptionInformationCrudProcessed) { + if (shouldDisableSchemaValidation || encryptionInformationCrudProcessed) { + _documentSchemaValidationDisabler.emplace( + opCtx, DocumentValidationSettings::kDisableSafeContentValidation); + } } private: boost::optional<DisableDocumentValidation> _documentSchemaValidationDisabler; }; + } // namespace mongo diff --git a/src/mongo/db/commands/find_and_modify.cpp b/src/mongo/db/commands/find_and_modify.cpp index 00203a4c485..abbc0d834fd 100644 --- a/src/mongo/db/commands/find_and_modify.cpp +++ b/src/mongo/db/commands/find_and_modify.cpp @@ -636,10 +636,15 @@ write_ops::FindAndModifyCommandReply CmdFindAndModify::Invocation::typedRun( // Collect metrics. 
CmdFindAndModify::collectMetrics(req); - boost::optional<DisableDocumentValidation> maybeDisableValidation; - if (req.getBypassDocumentValidation().value_or(false)) { - maybeDisableValidation.emplace(opCtx); - } + auto disableDocumentValidation = req.getBypassDocumentValidation().value_or(false); + auto fleCrudProcessed = + write_ops_exec::getFleCrudProcessed(opCtx, req.getEncryptionInformation()); + + DisableDocumentSchemaValidationIfTrue docSchemaValidationDisabler(opCtx, + disableDocumentValidation); + + DisableSafeContentValidationIfTrue safeContentValidationDisabler( + opCtx, disableDocumentValidation, fleCrudProcessed); const auto inTransaction = opCtx->inMultiDocumentTransaction(); uassert(50781, diff --git a/src/mongo/db/commands/fle_compact_test.cpp b/src/mongo/db/commands/fle_compact_test.cpp index 18c52f548ef..26153aadcc8 100644 --- a/src/mongo/db/commands/fle_compact_test.cpp +++ b/src/mongo/db/commands/fle_compact_test.cpp @@ -395,8 +395,13 @@ void FleCompactTest::doSingleInsert(int id, BSONObj encryptedFieldsObj) { auto efc = generateEncryptedFieldConfig(encryptedFieldsObj.getFieldNames<std::set<std::string>>()); - uassertStatusOK(processInsert( - _queryImpl.get(), _namespaces.edcNss, serverPayload, efc, kUninitializedTxnNumber, result)); + uassertStatusOK(processInsert(_queryImpl.get(), + _namespaces.edcNss, + serverPayload, + efc, + kUninitializedTxnNumber, + result, + false)); } void FleCompactTest::doSingleDelete(int id, BSONObj encryptedFieldsObj) { diff --git a/src/mongo/db/commands/set_feature_compatibility_version_command.cpp b/src/mongo/db/commands/set_feature_compatibility_version_command.cpp index c63a2978e11..73933d1abe2 100644 --- a/src/mongo/db/commands/set_feature_compatibility_version_command.cpp +++ b/src/mongo/db/commands/set_feature_compatibility_version_command.cpp @@ -66,7 +66,6 @@ #include "mongo/db/repl/tenant_migration_recipient_service.h" #include "mongo/db/s/active_migrations_registry.h" #include "mongo/db/s/balancer/balancer.h" -#include "mongo/db/s/collection_sharding_state.h" #include "mongo/db/s/config/configsvr_coordinator_service.h" #include "mongo/db/s/config/sharding_catalog_manager.h" #include "mongo/db/s/migration_coordinator_document_gen.h" @@ -75,11 +74,11 @@ #include "mongo/db/s/resharding/coordinator_document_gen.h" #include "mongo/db/s/resharding/resharding_coordinator_service.h" #include "mongo/db/s/resharding/resharding_donor_recipient_common.h" -#include "mongo/db/s/shard_metadata_util.h" +#include "mongo/db/s/shard_filtering_metadata_refresh.h" #include "mongo/db/s/sharding_ddl_coordinator_service.h" +#include "mongo/db/s/sharding_state.h" #include "mongo/db/s/sharding_util.h" #include "mongo/db/s/transaction_coordinator_service.h" -#include "mongo/db/s/type_shard_collection.h" #include "mongo/db/server_feature_flags_gen.h" #include "mongo/db/server_options.h" #include "mongo/db/session_catalog.h" @@ -91,6 +90,7 @@ #include "mongo/logv2/log.h" #include "mongo/rpc/get_status_from_command_result.h" #include "mongo/s/catalog/sharding_catalog_client.h" +#include "mongo/s/catalog_cache_loader.h" #include "mongo/s/pm2423_feature_flags_gen.h" #include "mongo/s/pm2583_feature_flags_gen.h" #include "mongo/s/refine_collection_shard_key_coordinator_feature_flags_gen.h" @@ -225,20 +225,6 @@ void uassertStatusOKIgnoreNSNotFound(Status status) { uassertStatusOK(status); } -void clearFilteringMetadataOnSecondaries(OperationContext* opCtx, const NamespaceString& collName) { - Status signalStatus = shardmetadatautil::updateShardCollectionsEntry( - 
opCtx, - BSON(ShardCollectionType::kNssFieldName << collName.ns()), - BSON("$inc" << BSON(ShardCollectionType::kEnterCriticalSectionCounterFieldName << 1)), - false /*upsert*/); - - uassertStatusOKWithContext( - signalStatus, - str::stream() - << "Failed to persist signal to clear the filtering metadata on secondaries for nss " - << collName.ns()); -} - /** * Sets the minimum allowed feature compatibility version for the cluster. The cluster should not * use any new features introduced in binary versions that are newer than the feature compatibility @@ -564,26 +550,18 @@ public: } } - if (requestedVersion == multiversion::FeatureCompatibilityVersion::kVersion_5_3 || - requestedVersion == multiversion::FeatureCompatibilityVersion::kVersion_6_0) { + if (requestedVersion == multiversion::FeatureCompatibilityVersion::kVersion_6_0 && + ShardingState::get(opCtx)->enabled()) { const auto colls = CollectionShardingState::getCollectionNames(opCtx); for (const auto& collName : colls) { - try { - if (!collName.isSystemDotViews()) { - { - AutoGetCollection coll(opCtx, collName, MODE_IX); - CollectionShardingState::get(opCtx, collName) - ->clearFilteringMetadata_DoNotUseIt(opCtx); - } - clearFilteringMetadataOnSecondaries(opCtx, collName); - } - } catch (const ExceptionFor<ErrorCodes::CommandNotSupportedOnView>&) { - // Nothing to do since collName is a view - } + onShardVersionMismatch(opCtx, collName, boost::none); + CatalogCacheLoader::get(opCtx).waitForCollectionFlush(opCtx, collName); } - // Wait until the signals to clear the filtering metadata on secondary nodes are - // majority committed. + repl::ReplClientInfo::forClient(opCtx->getClient()) + .setLastOpToSystemLastOpTime(opCtx); + + // Wait until the changes on config.cache.* are majority committed. WriteConcernResult ignoreResult; auto latestOpTime = repl::ReplClientInfo::forClient(opCtx->getClient()).getLastOp(); uassertStatusOK(waitForWriteConcern(opCtx, diff --git a/src/mongo/db/commands/write_commands.cpp b/src/mongo/db/commands/write_commands.cpp index 9e6d189b4a3..0254baca47d 100644 --- a/src/mongo/db/commands/write_commands.cpp +++ b/src/mongo/db/commands/write_commands.cpp @@ -529,7 +529,8 @@ public: write_ops::InsertCommandReply typedRun(OperationContext* opCtx) final try { transactionChecks(opCtx, ns()); - if (request().getEncryptionInformation().has_value()) { + if (request().getEncryptionInformation().has_value() && + !request().getEncryptionInformation()->getCrudProcessed()) { write_ops::InsertCommandReply insertReply; auto batch = processFLEInsert(opCtx, request(), &insertReply); if (batch == FLEBatchResult::kProcessed) { @@ -1456,7 +1457,8 @@ public: write_ops::UpdateCommandReply updateReply; OperationSource source = OperationSource::kStandard; - if (request().getEncryptionInformation().has_value()) { + if (request().getEncryptionInformation().has_value() && + !request().getEncryptionInformation().get().getCrudProcessed()) { return processFLEUpdate(opCtx, request()); } diff --git a/src/mongo/db/dbmessage.h b/src/mongo/db/dbmessage.h index 95e0cfc6f7f..e280bf296be 100644 --- a/src/mongo/db/dbmessage.h +++ b/src/mongo/db/dbmessage.h @@ -227,7 +227,7 @@ public: * Indicates whether this message is expected to have a ns. 
*/ bool messageShouldHaveNs() const { - return (_msg.operation() >= dbUpdate) & (_msg.operation() <= dbDelete); + return static_cast<int>(_msg.operation() >= dbUpdate) & (_msg.operation() <= dbDelete); } /** diff --git a/src/mongo/db/exec/add_fields_projection_executor.cpp b/src/mongo/db/exec/add_fields_projection_executor.cpp index 592074b4834..a0fd7f08580 100644 --- a/src/mongo/db/exec/add_fields_projection_executor.cpp +++ b/src/mongo/db/exec/add_fields_projection_executor.cpp @@ -92,38 +92,6 @@ private: // The original object. Used to generate more helpful error messages. const BSONObj& _rawObj; - // Custom comparator that orders fieldpath strings by path prefix first, then by field. - struct PathPrefixComparator { - static constexpr char dot = '.'; - - // Returns true if the lhs value should sort before the rhs, false otherwise. - bool operator()(const std::string& lhs, const std::string& rhs) const { - for (size_t pos = 0, len = std::min(lhs.size(), rhs.size()); pos < len; ++pos) { - auto &lchar = lhs[pos], &rchar = rhs[pos]; - if (lchar == rchar) { - continue; - } - - // Consider the path delimiter '.' as being less than all other characters, so that - // paths sort directly before any paths they prefix and directly after any paths - // which prefix them. - if (lchar == dot) { - return true; - } else if (rchar == dot) { - return false; - } - - // Otherwise, default to normal character comparison. - return lchar < rchar; - } - - // If we get here, then we have reached the end of lhs and/or rhs and all of their path - // segments up to this point match. If lhs is shorter than rhs, then lhs prefixes rhs - // and should sort before it. - return lhs.size() < rhs.size(); - } - }; - // Tracks which paths we've seen to ensure no two paths conflict with each other. std::set<std::string, PathPrefixComparator> _seenPaths; }; diff --git a/src/mongo/db/fle_crud.cpp b/src/mongo/db/fle_crud.cpp index a847304b55f..58136f31c44 100644 --- a/src/mongo/db/fle_crud.cpp +++ b/src/mongo/db/fle_crud.cpp @@ -190,16 +190,20 @@ std::pair<FLEBatchResult, write_ops::InsertCommandReply> processInsert( auto edcNss = insertRequest.getNamespace(); auto ei = insertRequest.getEncryptionInformation().get(); + bool bypassDocumentValidation = + insertRequest.getWriteCommandRequestBase().getBypassDocumentValidation(); + auto efc = EncryptionInformationHelpers::getAndValidateSchema(edcNss, ei); auto documents = insertRequest.getDocuments(); // TODO - how to check if a document will be too large??? 
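processInsert in this hunk now threads bypassDocumentValidation through to EDCServerCollection::validateEncryptedFieldInfo, whose new guard (added in the fle_crypto.cpp hunk above) rejects user-supplied documents that already carry __safeContent__ unless validation is being bypassed. A rough standalone sketch of that guard, using a plain set of field names instead of BSONObj and a std::runtime_error in place of uassert(6666200):

```cpp
#include <iostream>
#include <set>
#include <stdexcept>
#include <string>

// Stand-in for the real BSONObj/kSafeContent machinery: just top-level field names.
const std::string kSafeContent = "__safeContent__";
using Doc = std::set<std::string>;

// Mirrors the guard added to EDCServerCollection::validateEncryptedFieldInfo:
// a user document may not carry __safeContent__ unless document validation is
// being bypassed (e.g. by the internal FLE rewrite itself).
void validateNoUserSafeContent(const Doc& doc, bool bypassDocumentValidation) {
    if (doc.count(kSafeContent) && !bypassDocumentValidation) {
        throw std::runtime_error("Cannot modify " + kSafeContent + " field in document.");
    }
}

int main() {
    validateNoUserSafeContent({"name", "ssn"}, /*bypassDocumentValidation=*/false);  // ok
    validateNoUserSafeContent({kSafeContent}, /*bypassDocumentValidation=*/true);    // ok
    try {
        validateNoUserSafeContent({kSafeContent}, /*bypassDocumentValidation=*/false);
    } catch (const std::runtime_error& e) {
        std::cout << "rejected: " << e.what() << '\n';
    }
    return 0;
}
```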
+ uassert(6371202, "Only single insert batches are supported in Queryable Encryption", documents.size() == 1); auto document = documents[0]; - EDCServerCollection::validateEncryptedFieldInfo(document, efc); + EDCServerCollection::validateEncryptedFieldInfo(document, efc, bypassDocumentValidation); auto serverPayload = std::make_shared<std::vector<EDCServerPayloadInfo>>( EDCServerCollection::getEncryptedFieldInfo(document)); @@ -223,8 +227,8 @@ std::pair<FLEBatchResult, write_ops::InsertCommandReply> processInsert( auto swResult = trun->runNoThrow( opCtx, - [sharedInsertBlock, reply, ownedDocument](const txn_api::TransactionClient& txnClient, - ExecutorPtr txnExec) { + [sharedInsertBlock, reply, ownedDocument, bypassDocumentValidation]( + const txn_api::TransactionClient& txnClient, ExecutorPtr txnExec) { FLEQueryInterfaceImpl queryImpl(txnClient, getGlobalServiceContext()); auto [edcNss2, efc2, serverPayload2, stmtId2] = *sharedInsertBlock.get(); @@ -234,8 +238,13 @@ std::pair<FLEBatchResult, write_ops::InsertCommandReply> processInsert( fleCrudHangPreInsert.pauseWhileSet(); } - *reply = uassertStatusOK(processInsert( - &queryImpl, edcNss2, *serverPayload2.get(), efc2, stmtId2, ownedDocument)); + *reply = uassertStatusOK(processInsert(&queryImpl, + edcNss2, + *serverPayload2.get(), + efc2, + stmtId2, + ownedDocument, + bypassDocumentValidation)); if (MONGO_unlikely(fleCrudHangInsert.shouldFail())) { LOGV2(6371903, "Hanging due to fleCrudHangInsert fail point"); @@ -441,7 +450,8 @@ void processFieldsForInsert(FLEQueryInterface* queryImpl, const NamespaceString& edcNss, std::vector<EDCServerPayloadInfo>& serverPayload, const EncryptedFieldConfig& efc, - int32_t* pStmtId) { + int32_t* pStmtId, + bool bypassDocumentValidation) { NamespaceString nssEsc(edcNss.db(), efc.getEscCollection().get()); @@ -509,7 +519,8 @@ void processFieldsForInsert(FLEQueryInterface* queryImpl, ECOCCollection::generateDocument(payload.fieldPathName, payload.payload.getEncryptedTokens()), pStmtId, - false)); + false, + bypassDocumentValidation)); checkWriteErrors(ecocInsertReply); } } @@ -719,9 +730,11 @@ StatusWith<write_ops::InsertCommandReply> processInsert( std::vector<EDCServerPayloadInfo>& serverPayload, const EncryptedFieldConfig& efc, int32_t stmtId, - BSONObj document) { + BSONObj document, + bool bypassDocumentValidation) { - processFieldsForInsert(queryImpl, edcNss, serverPayload, efc, &stmtId); + processFieldsForInsert( + queryImpl, edcNss, serverPayload, efc, &stmtId, bypassDocumentValidation); auto finalDoc = EDCServerCollection::finalizeForInsert(document, serverPayload); @@ -792,6 +805,9 @@ write_ops::UpdateCommandReply processUpdate(FLEQueryInterface* queryImpl, auto tokenMap = EncryptionInformationHelpers::getDeleteTokens(edcNss, ei); const auto updateOpEntry = updateRequest.getUpdates()[0]; + auto bypassDocumentValidation = + updateRequest.getWriteCommandRequestBase().getBypassDocumentValidation(); + const auto updateModification = updateOpEntry.getU(); int32_t stmtId = getStmtIdForWriteAt(updateRequest, 0); @@ -799,16 +815,26 @@ write_ops::UpdateCommandReply processUpdate(FLEQueryInterface* queryImpl, // Step 1 ---- std::vector<EDCServerPayloadInfo> serverPayload; auto newUpdateOpEntry = updateRequest.getUpdates()[0]; - newUpdateOpEntry.setQ(fle::rewriteEncryptedFilterInsideTxn( - queryImpl, updateRequest.getDbName(), efc, expCtx, newUpdateOpEntry.getQ())); + + auto highCardinalityModeAllowed = newUpdateOpEntry.getUpsert() + ? 
fle::HighCardinalityModeAllowed::kDisallow + : fle::HighCardinalityModeAllowed::kAllow; + + newUpdateOpEntry.setQ(fle::rewriteEncryptedFilterInsideTxn(queryImpl, + updateRequest.getDbName(), + efc, + expCtx, + newUpdateOpEntry.getQ(), + highCardinalityModeAllowed)); if (updateModification.type() == write_ops::UpdateModification::Type::kModifier) { auto updateModifier = updateModification.getUpdateModifier(); auto setObject = updateModifier.getObjectField("$set"); - EDCServerCollection::validateEncryptedFieldInfo(setObject, efc); + EDCServerCollection::validateEncryptedFieldInfo(setObject, efc, bypassDocumentValidation); serverPayload = EDCServerCollection::getEncryptedFieldInfo(updateModifier); - processFieldsForInsert(queryImpl, edcNss, serverPayload, efc, &stmtId); + processFieldsForInsert( + queryImpl, edcNss, serverPayload, efc, &stmtId, bypassDocumentValidation); // Step 2 ---- auto pushUpdate = EDCServerCollection::finalizeForUpdate(updateModifier, serverPayload); @@ -817,10 +843,12 @@ write_ops::UpdateCommandReply processUpdate(FLEQueryInterface* queryImpl, pushUpdate, write_ops::UpdateModification::ClassicTag(), false)); } else { auto replacementDocument = updateModification.getUpdateReplacement(); - EDCServerCollection::validateEncryptedFieldInfo(replacementDocument, efc); + EDCServerCollection::validateEncryptedFieldInfo( + replacementDocument, efc, bypassDocumentValidation); serverPayload = EDCServerCollection::getEncryptedFieldInfo(replacementDocument); - processFieldsForInsert(queryImpl, edcNss, serverPayload, efc, &stmtId); + processFieldsForInsert( + queryImpl, edcNss, serverPayload, efc, &stmtId, bypassDocumentValidation); // Step 2 ---- auto safeContentReplace = @@ -835,6 +863,8 @@ write_ops::UpdateCommandReply processUpdate(FLEQueryInterface* queryImpl, newUpdateRequest.setUpdates({newUpdateOpEntry}); newUpdateRequest.getWriteCommandRequestBase().setStmtIds(boost::none); newUpdateRequest.getWriteCommandRequestBase().setStmtId(stmtId); + newUpdateRequest.getWriteCommandRequestBase().setBypassDocumentValidation( + bypassDocumentValidation); ++stmtId; auto [updateReply, originalDocument] = @@ -892,6 +922,10 @@ FLEBatchResult processFLEBatch(OperationContext* opCtx, BatchedCommandResponse* response, boost::optional<OID> targetEpoch) { + if (request.getWriteCommandRequestBase().getEncryptionInformation()->getCrudProcessed()) { + return FLEBatchResult::kNotProcessed; + } + // TODO (SERVER-65077): Remove FCV check once 6.0 is released uassert(6371209, "Queryable Encryption is only supported when FCV supports 6.0", @@ -970,19 +1004,25 @@ std::unique_ptr<BatchedCommandRequest> processFLEBatchExplain( request.getNS(), deleteRequest.getEncryptionInformation().get(), newDeleteOp.getQ(), - &getTransactionWithRetriesForMongoS)); + &getTransactionWithRetriesForMongoS, + fle::HighCardinalityModeAllowed::kAllow)); deleteRequest.setDeletes({newDeleteOp}); deleteRequest.getWriteCommandRequestBase().setEncryptionInformation(boost::none); return std::make_unique<BatchedCommandRequest>(deleteRequest); } else if (request.getBatchType() == BatchedCommandRequest::BatchType_Update) { auto updateRequest = request.getUpdateRequest(); auto newUpdateOp = updateRequest.getUpdates()[0]; + auto highCardinalityModeAllowed = newUpdateOp.getUpsert() + ? 
fle::HighCardinalityModeAllowed::kDisallow + : fle::HighCardinalityModeAllowed::kAllow; + newUpdateOp.setQ(fle::rewriteQuery(opCtx, getExpCtx(newUpdateOp), request.getNS(), updateRequest.getEncryptionInformation().get(), newUpdateOp.getQ(), - &getTransactionWithRetriesForMongoS)); + &getTransactionWithRetriesForMongoS, + highCardinalityModeAllowed)); updateRequest.setUpdates({newUpdateOp}); updateRequest.getWriteCommandRequestBase().setEncryptionInformation(boost::none); return std::make_unique<BatchedCommandRequest>(updateRequest); @@ -1005,10 +1045,22 @@ write_ops::FindAndModifyCommandReply processFindAndModify( auto newFindAndModifyRequest = findAndModifyRequest; + const auto bypassDocumentValidation = + findAndModifyRequest.getBypassDocumentValidation().value_or(false); + // Step 0 ---- // Rewrite filter - newFindAndModifyRequest.setQuery(fle::rewriteEncryptedFilterInsideTxn( - queryImpl, edcNss.db(), efc, expCtx, findAndModifyRequest.getQuery())); + auto highCardinalityModeAllowed = findAndModifyRequest.getUpsert().value_or(false) + ? fle::HighCardinalityModeAllowed::kDisallow + : fle::HighCardinalityModeAllowed::kAllow; + + newFindAndModifyRequest.setQuery( + fle::rewriteEncryptedFilterInsideTxn(queryImpl, + edcNss.db(), + efc, + expCtx, + findAndModifyRequest.getQuery(), + highCardinalityModeAllowed)); // Make sure not to inherit the command's writeConcern, this should be set at the transaction // level. @@ -1025,9 +1077,11 @@ write_ops::FindAndModifyCommandReply processFindAndModify( if (updateModification.type() == write_ops::UpdateModification::Type::kModifier) { auto updateModifier = updateModification.getUpdateModifier(); auto setObject = updateModifier.getObjectField("$set"); - EDCServerCollection::validateEncryptedFieldInfo(setObject, efc); + EDCServerCollection::validateEncryptedFieldInfo( + setObject, efc, bypassDocumentValidation); serverPayload = EDCServerCollection::getEncryptedFieldInfo(updateModifier); - processFieldsForInsert(queryImpl, edcNss, serverPayload, efc, &stmtId); + processFieldsForInsert( + queryImpl, edcNss, serverPayload, efc, &stmtId, bypassDocumentValidation); auto pushUpdate = EDCServerCollection::finalizeForUpdate(updateModifier, serverPayload); @@ -1036,10 +1090,12 @@ write_ops::FindAndModifyCommandReply processFindAndModify( pushUpdate, write_ops::UpdateModification::ClassicTag(), false); } else { auto replacementDocument = updateModification.getUpdateReplacement(); - EDCServerCollection::validateEncryptedFieldInfo(replacementDocument, efc); + EDCServerCollection::validateEncryptedFieldInfo( + replacementDocument, efc, bypassDocumentValidation); serverPayload = EDCServerCollection::getEncryptedFieldInfo(replacementDocument); - processFieldsForInsert(queryImpl, edcNss, serverPayload, efc, &stmtId); + processFieldsForInsert( + queryImpl, edcNss, serverPayload, efc, &stmtId, bypassDocumentValidation); // Step 2 ---- auto safeContentReplace = @@ -1131,8 +1187,17 @@ write_ops::FindAndModifyCommandRequest processFindAndModifyExplain( auto efc = EncryptionInformationHelpers::getAndValidateSchema(edcNss, ei); auto newFindAndModifyRequest = findAndModifyRequest; - newFindAndModifyRequest.setQuery(fle::rewriteEncryptedFilterInsideTxn( - queryImpl, edcNss.db(), efc, expCtx, findAndModifyRequest.getQuery())); + auto highCardinalityModeAllowed = findAndModifyRequest.getUpsert().value_or(false) + ? 
fle::HighCardinalityModeAllowed::kDisallow + : fle::HighCardinalityModeAllowed::kAllow; + + newFindAndModifyRequest.setQuery( + fle::rewriteEncryptedFilterInsideTxn(queryImpl, + edcNss.db(), + efc, + expCtx, + findAndModifyRequest.getQuery(), + highCardinalityModeAllowed)); newFindAndModifyRequest.setEncryptionInformation(boost::none); return newFindAndModifyRequest; @@ -1234,10 +1299,23 @@ uint64_t FLEQueryInterfaceImpl::countDocuments(const NamespaceString& nss) { } StatusWith<write_ops::InsertCommandReply> FLEQueryInterfaceImpl::insertDocument( - const NamespaceString& nss, BSONObj obj, StmtId* pStmtId, bool translateDuplicateKey) { + const NamespaceString& nss, + BSONObj obj, + StmtId* pStmtId, + bool translateDuplicateKey, + bool bypassDocumentValidation) { write_ops::InsertCommandRequest insertRequest(nss); insertRequest.setDocuments({obj}); + EncryptionInformation encryptionInformation; + encryptionInformation.setCrudProcessed(true); + + // We need to set an empty BSON object here for the schema. + encryptionInformation.setSchema(BSONObj()); + insertRequest.getWriteCommandRequestBase().setEncryptionInformation(encryptionInformation); + insertRequest.getWriteCommandRequestBase().setBypassDocumentValidation( + bypassDocumentValidation); + int32_t stmtId = *pStmtId; if (stmtId != kUninitializedStmtId) { (*pStmtId)++; @@ -1322,6 +1400,7 @@ std::pair<write_ops::UpdateCommandReply, BSONObj> FLEQueryInterfaceImpl::updateW findAndModifyRequest.setLet( mergeLetAndCVariables(updateRequest.getLet(), updateOpEntry.getC())); findAndModifyRequest.setStmtId(updateRequest.getStmtId()); + findAndModifyRequest.setBypassDocumentValidation(updateRequest.getBypassDocumentValidation()); auto ei2 = ei; ei2.setCrudProcessed(true); @@ -1363,9 +1442,15 @@ std::pair<write_ops::UpdateCommandReply, BSONObj> FLEQueryInterfaceImpl::updateW } write_ops::UpdateCommandReply FLEQueryInterfaceImpl::update( - const NamespaceString& nss, - int32_t stmtId, - const write_ops::UpdateCommandRequest& updateRequest) { + const NamespaceString& nss, int32_t stmtId, write_ops::UpdateCommandRequest& updateRequest) { + + invariant(!updateRequest.getWriteCommandRequestBase().getEncryptionInformation()); + + EncryptionInformation encryptionInformation; + encryptionInformation.setCrudProcessed(true); + + encryptionInformation.setSchema(BSONObj()); + updateRequest.getWriteCommandRequestBase().setEncryptionInformation(encryptionInformation); dassert(updateRequest.getStmtIds().value_or(std::vector<int32_t>()).empty()); diff --git a/src/mongo/db/fle_crud.h b/src/mongo/db/fle_crud.h index 738e85b8996..7c8d93ae1f9 100644 --- a/src/mongo/db/fle_crud.h +++ b/src/mongo/db/fle_crud.h @@ -261,7 +261,11 @@ public: * FLEStateCollectionContention instead. */ virtual StatusWith<write_ops::InsertCommandReply> insertDocument( - const NamespaceString& nss, BSONObj obj, StmtId* pStmtId, bool translateDuplicateKey) = 0; + const NamespaceString& nss, + BSONObj obj, + StmtId* pStmtId, + bool translateDuplicateKey, + bool bypassDocumentValidation = false) = 0; /** * Delete a single document with the given query. @@ -294,7 +298,7 @@ public: virtual write_ops::UpdateCommandReply update( const NamespaceString& nss, int32_t stmtId, - const write_ops::UpdateCommandRequest& updateRequest) = 0; + write_ops::UpdateCommandRequest& updateRequest) = 0; /** * Do a single findAndModify request. 
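The update and findAndModify rewrites above all derive fle::HighCardinalityModeAllowed the same way: upserts get kDisallow, everything else kAllow. A tiny sketch of that selection as a free helper, with OpEntry as a hypothetical stand-in for the real op-entry/request types:

```cpp
#include <iostream>

namespace fle {
// Mirrors the mode enum used in the fle_crud.cpp hunks above.
enum class HighCardinalityModeAllowed { kAllow, kDisallow };
}  // namespace fle

// Hypothetical stand-in for an update or findAndModify op entry.
struct OpEntry {
    bool upsert = false;
};

// The hunks map upsert == true to kDisallow and everything else to kAllow.
fle::HighCardinalityModeAllowed modeFor(const OpEntry& op) {
    return op.upsert ? fle::HighCardinalityModeAllowed::kDisallow
                     : fle::HighCardinalityModeAllowed::kAllow;
}

int main() {
    std::cout << std::boolalpha
              << (modeFor({/*upsert=*/true}) == fle::HighCardinalityModeAllowed::kDisallow)
              << ' '
              << (modeFor({/*upsert=*/false}) == fle::HighCardinalityModeAllowed::kAllow)
              << '\n';
    return 0;
}
```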
@@ -325,10 +329,12 @@ public: uint64_t countDocuments(const NamespaceString& nss) final; - StatusWith<write_ops::InsertCommandReply> insertDocument(const NamespaceString& nss, - BSONObj obj, - int32_t* pStmtId, - bool translateDuplicateKey) final; + StatusWith<write_ops::InsertCommandReply> insertDocument( + const NamespaceString& nss, + BSONObj obj, + int32_t* pStmtId, + bool translateDuplicateKey, + bool bypassDocumentValidation = false) final; std::pair<write_ops::DeleteCommandReply, BSONObj> deleteWithPreimage( const NamespaceString& nss, @@ -340,10 +346,9 @@ public: const EncryptionInformation& ei, const write_ops::UpdateCommandRequest& updateRequest) final; - write_ops::UpdateCommandReply update( - const NamespaceString& nss, - int32_t stmtId, - const write_ops::UpdateCommandRequest& updateRequest) final; + write_ops::UpdateCommandReply update(const NamespaceString& nss, + int32_t stmtId, + write_ops::UpdateCommandRequest& updateRequest) final; write_ops::FindAndModifyCommandReply findAndModify( const NamespaceString& nss, @@ -408,7 +413,8 @@ StatusWith<write_ops::InsertCommandReply> processInsert( std::vector<EDCServerPayloadInfo>& serverPayload, const EncryptedFieldConfig& efc, int32_t stmtId, - BSONObj document); + BSONObj document, + bool bypassDocumentValidation = false); /** * Process a FLE delete with the query interface diff --git a/src/mongo/db/fle_crud_mongod.cpp b/src/mongo/db/fle_crud_mongod.cpp index 68327133c88..1e488f1f65a 100644 --- a/src/mongo/db/fle_crud_mongod.cpp +++ b/src/mongo/db/fle_crud_mongod.cpp @@ -284,7 +284,13 @@ BSONObj processFLEWriteExplainD(OperationContext* opCtx, const BSONObj& query) { auto expCtx = make_intrusive<ExpressionContext>( opCtx, fle::collatorFromBSON(opCtx, collation), nss, runtimeConstants, letParameters); - return fle::rewriteQuery(opCtx, expCtx, nss, info, query, &getTransactionWithRetriesForMongoD); + return fle::rewriteQuery(opCtx, + expCtx, + nss, + info, + query, + &getTransactionWithRetriesForMongoD, + fle::HighCardinalityModeAllowed::kAllow); } std::pair<write_ops::FindAndModifyCommandRequest, OpMsgRequest> diff --git a/src/mongo/db/fle_crud_test.cpp b/src/mongo/db/fle_crud_test.cpp index 527dd5bca11..0a5d7dfc37c 100644 --- a/src/mongo/db/fle_crud_test.cpp +++ b/src/mongo/db/fle_crud_test.cpp @@ -27,6 +27,7 @@ * it in the license file. 
*/ +#include "mongo/base/error_codes.h" #include "mongo/platform/basic.h" #include <algorithm> @@ -153,8 +154,12 @@ protected: void assertDocumentCounts(uint64_t edc, uint64_t esc, uint64_t ecc, uint64_t ecoc); - void doSingleInsert(int id, BSONElement element); - void doSingleInsert(int id, BSONObj obj); + void testValidateEncryptedFieldInfo(BSONObj obj, bool bypassValidation); + + void testValidateTags(BSONObj obj); + + void doSingleInsert(int id, BSONElement element, bool bypassDocumentValidation = false); + void doSingleInsert(int id, BSONObj obj, bool bypassDocumentValidation = false); void doSingleInsertWithContention( int id, BSONElement element, int64_t cm, uint64_t cf, EncryptedFieldConfig efc); @@ -406,7 +411,7 @@ void FleCrudTest::doSingleWideInsert(int id, uint64_t fieldCount, ValueGenerator auto efc = getTestEncryptedFieldConfig(); - uassertStatusOK(processInsert(_queryImpl.get(), _edcNs, serverPayload, efc, 0, result)); + uassertStatusOK(processInsert(_queryImpl.get(), _edcNs, serverPayload, efc, 0, result, false)); } @@ -451,7 +456,16 @@ std::vector<char> generateSinglePlaceholder(BSONElement value, int64_t cm = 0) { return v; } -void FleCrudTest::doSingleInsert(int id, BSONElement element) { +void FleCrudTest::testValidateEncryptedFieldInfo(BSONObj obj, bool bypassValidation) { + auto efc = getTestEncryptedFieldConfig(); + EDCServerCollection::validateEncryptedFieldInfo(obj, efc, bypassValidation); +} + +void FleCrudTest::testValidateTags(BSONObj obj) { + FLEClientCrypto::validateTagsArray(obj); +} + +void FleCrudTest::doSingleInsert(int id, BSONElement element, bool bypassDocumentValidation) { auto buf = generateSinglePlaceholder(element); BSONObjBuilder builder; builder.append("_id", id); @@ -467,10 +481,10 @@ void FleCrudTest::doSingleInsert(int id, BSONElement element) { auto efc = getTestEncryptedFieldConfig(); - uassertStatusOK(processInsert(_queryImpl.get(), _edcNs, serverPayload, efc, 0, result)); + uassertStatusOK(processInsert(_queryImpl.get(), _edcNs, serverPayload, efc, 0, result, false)); } -void FleCrudTest::doSingleInsert(int id, BSONObj obj) { +void FleCrudTest::doSingleInsert(int id, BSONObj obj, bool bypassDocumentValidation) { doSingleInsert(id, obj.firstElement()); } @@ -490,7 +504,7 @@ void FleCrudTest::doSingleInsertWithContention( auto serverPayload = EDCServerCollection::getEncryptedFieldInfo(result); - uassertStatusOK(processInsert(_queryImpl.get(), _edcNs, serverPayload, efc, 0, result)); + uassertStatusOK(processInsert(_queryImpl.get(), _edcNs, serverPayload, efc, 0, result, false)); } void FleCrudTest::doSingleInsertWithContention( @@ -890,7 +904,6 @@ TEST_F(FleCrudTest, UpdateOneSameValue) { << "secret")); } - // Update one document with replacement TEST_F(FleCrudTest, UpdateOneReplace) { @@ -956,7 +969,16 @@ TEST_F(FleCrudTest, SetSafeContent) { builder.append("$set", BSON(kSafeContent << "foo")); auto result = builder.obj(); - ASSERT_THROWS_CODE(doSingleUpdateWithUpdateDoc(1, result), DBException, 6371507); + ASSERT_THROWS_CODE(doSingleUpdateWithUpdateDoc(1, result), DBException, 6666200); +} + +// Test that EDCServerCollection::validateEncryptedFieldInfo checks that the +// safeContent cannot be present in the BSON obj. 
+TEST_F(FleCrudTest, testValidateEncryptedFieldConfig) { + testValidateEncryptedFieldInfo(BSON(kSafeContent << "secret"), true); + ASSERT_THROWS_CODE(testValidateEncryptedFieldInfo(BSON(kSafeContent << "secret"), false), + DBException, + 6666200); } // Update one document via findAndModify @@ -1038,6 +1060,11 @@ TEST_F(FleCrudTest, FindAndModify_RenameSafeContent) { ASSERT_THROWS_CODE(doFindAndModify(req), DBException, 6371506); } +TEST_F(FleCrudTest, validateTagsTest) { + testValidateTags(BSON(kSafeContent << BSON_ARRAY(123))); + ASSERT_THROWS_CODE(testValidateTags(BSON(kSafeContent << "foo")), DBException, 6371507); +} + // Mess with __safeContent__ and ensure the update errors TEST_F(FleCrudTest, FindAndModify_SetSafeContent) { doSingleInsert(1, @@ -1056,8 +1083,7 @@ TEST_F(FleCrudTest, FindAndModify_SetSafeContent) { req.setUpdate( write_ops::UpdateModification(result, write_ops::UpdateModification::ClassicTag{}, false)); - - ASSERT_THROWS_CODE(doFindAndModify(req), DBException, 6371507); + ASSERT_THROWS_CODE(doFindAndModify(req), DBException, 6666200); } TEST_F(FleTagsTest, InsertOne) { @@ -1199,7 +1225,7 @@ TEST_F(FleTagsTest, MemoryLimit) { doSingleInsert(10, doc); // readTags returns 11 tags which does exceed memory limit. - ASSERT_THROWS_CODE(readTags(doc), DBException, 6401800); + ASSERT_THROWS_CODE(readTags(doc), DBException, ErrorCodes::FLEMaxTagLimitExceeded); doSingleDelete(5); diff --git a/src/mongo/db/fle_query_interface_mock.cpp b/src/mongo/db/fle_query_interface_mock.cpp index 2aeb39788dd..b5ca4e1e9cd 100644 --- a/src/mongo/db/fle_query_interface_mock.cpp +++ b/src/mongo/db/fle_query_interface_mock.cpp @@ -54,7 +54,11 @@ uint64_t FLEQueryInterfaceMock::countDocuments(const NamespaceString& nss) { } StatusWith<write_ops::InsertCommandReply> FLEQueryInterfaceMock::insertDocument( - const NamespaceString& nss, BSONObj obj, StmtId* pStmtId, bool translateDuplicateKey) { + const NamespaceString& nss, + BSONObj obj, + StmtId* pStmtId, + bool translateDuplicateKey, + bool bypassDocumentValidation) { repl::TimestampedBSONObj tb; tb.obj = obj; @@ -132,9 +136,7 @@ std::pair<write_ops::UpdateCommandReply, BSONObj> FLEQueryInterfaceMock::updateW } write_ops::UpdateCommandReply FLEQueryInterfaceMock::update( - const NamespaceString& nss, - int32_t stmtId, - const write_ops::UpdateCommandRequest& updateRequest) { + const NamespaceString& nss, int32_t stmtId, write_ops::UpdateCommandRequest& updateRequest) { auto [reply, _] = updateWithPreimage(nss, EncryptionInformation(), updateRequest); return reply; } diff --git a/src/mongo/db/fle_query_interface_mock.h b/src/mongo/db/fle_query_interface_mock.h index 229d2c08dfe..a89fc71ce1e 100644 --- a/src/mongo/db/fle_query_interface_mock.h +++ b/src/mongo/db/fle_query_interface_mock.h @@ -47,10 +47,12 @@ public: uint64_t countDocuments(const NamespaceString& nss) final; - StatusWith<write_ops::InsertCommandReply> insertDocument(const NamespaceString& nss, - BSONObj obj, - StmtId* pStmtId, - bool translateDuplicateKey) final; + StatusWith<write_ops::InsertCommandReply> insertDocument( + const NamespaceString& nss, + BSONObj obj, + StmtId* pStmtId, + bool translateDuplicateKey, + bool bypassDocumentValidation = false) final; std::pair<write_ops::DeleteCommandReply, BSONObj> deleteWithPreimage( const NamespaceString& nss, @@ -62,10 +64,9 @@ public: const EncryptionInformation& ei, const write_ops::UpdateCommandRequest& updateRequest) final; - write_ops::UpdateCommandReply update( - const NamespaceString& nss, - int32_t stmtId, - const 
write_ops::UpdateCommandRequest& updateRequest) final; + write_ops::UpdateCommandReply update(const NamespaceString& nss, + int32_t stmtId, + write_ops::UpdateCommandRequest& updateRequest) final; write_ops::FindAndModifyCommandReply findAndModify( const NamespaceString& nss, diff --git a/src/mongo/db/index_builds_coordinator.cpp b/src/mongo/db/index_builds_coordinator.cpp index 1e06ec2932a..45c1afefdae 100644 --- a/src/mongo/db/index_builds_coordinator.cpp +++ b/src/mongo/db/index_builds_coordinator.cpp @@ -29,10 +29,10 @@ #define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kStorage -#include "mongo/platform/basic.h" - #include "mongo/db/index_builds_coordinator.h" +#include <fmt/format.h> + #include "mongo/db/catalog/clustered_collection_util.h" #include "mongo/db/catalog/collection_catalog.h" #include "mongo/db/catalog/commit_quorum_options.h" @@ -1659,19 +1659,39 @@ void IndexBuildsCoordinator::assertNoIndexBuildInProgress() const { void IndexBuildsCoordinator::assertNoIndexBuildInProgForCollection( const UUID& collectionUUID) const { + boost::optional<UUID> firstIndexBuildUUID; + auto indexBuilds = activeIndexBuilds.filterIndexBuilds([&](const auto& replState) { + auto isIndexBuildForCollection = (collectionUUID == replState.collectionUUID); + if (isIndexBuildForCollection && !firstIndexBuildUUID) { + firstIndexBuildUUID = replState.buildUUID; + }; + return isIndexBuildForCollection; + }); + uassert(ErrorCodes::BackgroundOperationInProgressForNamespace, - str::stream() << "cannot perform operation: an index build is currently running for " - "collection with UUID: " - << collectionUUID, - !inProgForCollection(collectionUUID)); + fmt::format("cannot perform operation: an index build is currently running for " + "collection with UUID: {}. Found index build: {}", + collectionUUID.toString(), + firstIndexBuildUUID->toString()), + indexBuilds.empty()); } void IndexBuildsCoordinator::assertNoBgOpInProgForDb(StringData db) const { + boost::optional<UUID> firstIndexBuildUUID; + auto indexBuilds = activeIndexBuilds.filterIndexBuilds([&](const auto& replState) { + auto isIndexBuildForCollection = (db == replState.dbName); + if (isIndexBuildForCollection && !firstIndexBuildUUID) { + firstIndexBuildUUID = replState.buildUUID; + }; + return isIndexBuildForCollection; + }); + uassert(ErrorCodes::BackgroundOperationInProgressForDatabase, - str::stream() << "cannot perform operation: an index build is currently running for " - "database " - << db, - !inProgForDb(db)); + fmt::format("cannot perform operation: an index build is currently running for " + "database {}. Found index build: {}", + db, + firstIndexBuildUUID->toString()), + indexBuilds.empty()); } void IndexBuildsCoordinator::awaitNoIndexBuildInProgressForCollection(OperationContext* opCtx, diff --git a/src/mongo/db/index_builds_coordinator_mongod_test.cpp b/src/mongo/db/index_builds_coordinator_mongod_test.cpp index 28e09af1ef7..ecca38a9d94 100644 --- a/src/mongo/db/index_builds_coordinator_mongod_test.cpp +++ b/src/mongo/db/index_builds_coordinator_mongod_test.cpp @@ -86,6 +86,9 @@ void IndexBuildsCoordinatorMongodTest::setUp() { } void IndexBuildsCoordinatorMongodTest::tearDown() { + // Resume index builds left running by test failures so that shutdown() will not block. + _indexBuildsCoord->sleepIndexBuilds_forTestOnly(false); + _indexBuildsCoord->shutdown(operationContext()); _indexBuildsCoord.reset(); // All databases are dropped during tear down. 
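The index-build coordinator changes above replace the count-only error text with one that names the first offending build, scanning the active builds once and formatting the message with fmt::format (the header the hunk itself adds). A simplified sketch of the same shape, with IndexBuild as a stand-in for the real registry entry and std::runtime_error in place of uassert:

```cpp
#include <fmt/format.h>

#include <iostream>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for an active index build entry.
struct IndexBuild {
    std::string buildUUID;
    std::string dbName;
};

// Mirrors the shape of the new assertNoBgOpInProgForDb: remember the first
// matching build so its UUID can be reported alongside the database name.
void assertNoIndexBuildForDb(const std::vector<IndexBuild>& builds, const std::string& db) {
    std::optional<std::string> firstBuildUUID;
    std::size_t matching = 0;
    for (const auto& build : builds) {
        if (build.dbName == db) {
            if (!firstBuildUUID) {
                firstBuildUUID = build.buildUUID;
            }
            ++matching;
        }
    }
    if (matching != 0) {
        throw std::runtime_error(
            fmt::format("cannot perform operation: an index build is currently running for "
                        "database {}. Found index build: {}",
                        db,
                        *firstBuildUUID));
    }
}

int main() {
    std::vector<IndexBuild> builds{{"7f9c2ba4-e88f-11ec-8fea-0242ac120002", "test"}};
    try {
        assertNoIndexBuildForDb(builds, "test");
    } catch (const std::runtime_error& e) {
        std::cout << e.what() << '\n';
    }
    return 0;
}
```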
@@ -155,24 +158,27 @@ TEST_F(IndexBuildsCoordinatorMongodTest, Registration) { _indexBuildsCoord->sleepIndexBuilds_forTestOnly(true); // Register an index build on _testFooNss. + auto testFoo1BuildUUID = UUID::gen(); auto testFoo1Future = assertGet(_indexBuildsCoord->startIndexBuild(operationContext(), _testFooNss.db().toString(), _testFooUUID, makeSpecs(_testFooNss, {"a", "b"}), - UUID::gen(), + testFoo1BuildUUID, IndexBuildProtocol::kTwoPhase, _indexBuildOptions)); ASSERT_EQ(_indexBuildsCoord->numInProgForDb(_testFooNss.db()), 1); ASSERT(_indexBuildsCoord->inProgForCollection(_testFooUUID)); ASSERT(_indexBuildsCoord->inProgForDb(_testFooNss.db())); - ASSERT_THROWS_CODE(_indexBuildsCoord->assertNoIndexBuildInProgForCollection(_testFooUUID), - AssertionException, - ErrorCodes::BackgroundOperationInProgressForNamespace); - ASSERT_THROWS_CODE(_indexBuildsCoord->assertNoBgOpInProgForDb(_testFooNss.db()), - AssertionException, - ErrorCodes::BackgroundOperationInProgressForDatabase); + ASSERT_THROWS_WITH_CHECK( + _indexBuildsCoord->assertNoIndexBuildInProgForCollection(_testFooUUID), + ExceptionFor<ErrorCodes::BackgroundOperationInProgressForNamespace>, + [&](const auto& ex) { ASSERT_STRING_CONTAINS(ex.reason(), testFoo1BuildUUID.toString()); }); + ASSERT_THROWS_WITH_CHECK( + _indexBuildsCoord->assertNoBgOpInProgForDb(_testFooNss.db()), + ExceptionFor<ErrorCodes::BackgroundOperationInProgressForDatabase>, + [&](const auto& ex) { ASSERT_STRING_CONTAINS(ex.reason(), testFoo1BuildUUID.toString()); }); // Register a second index build on _testFooNss. auto testFoo2Future = @@ -382,7 +388,10 @@ TEST_F(IndexBuildsCoordinatorMongodTest, AbortBuildIndexDueToTenantMigration) { // we currently have one index build in progress. ASSERT_EQ(1, _indexBuildsCoord->getActiveIndexBuildCount(operationContext())); - ASSERT_THROWS(_indexBuildsCoord->assertNoIndexBuildInProgress(), mongo::DBException); + ASSERT_THROWS_WITH_CHECK( + _indexBuildsCoord->assertNoIndexBuildInProgress(), + ExceptionFor<ErrorCodes::BackgroundOperationInProgressForDatabase>, + [&](const auto& ex) { ASSERT_STRING_CONTAINS(ex.reason(), buildUUID.toString()); }); ASSERT_OK(_indexBuildsCoord->voteCommitIndexBuild( operationContext(), buildUUID, HostAndPort("test1", 1234))); diff --git a/src/mongo/db/internal_transactions_feature_flag.idl b/src/mongo/db/internal_transactions_feature_flag.idl index d0373f56140..bbbb9fa1477 100644 --- a/src/mongo/db/internal_transactions_feature_flag.idl +++ b/src/mongo/db/internal_transactions_feature_flag.idl @@ -41,6 +41,11 @@ feature_flags: default: true version: 6.0 + featureFlagAlwaysCreateConfigTransactionsPartialIndexOnStepUp: + description: Feature flag to enable always creating the config.transactions partial index on step up to primary even if the collection is not empty. + cpp_varname: gFeatureFlagAlwaysCreateConfigTransactionsPartialIndexOnStepUp + default: false + featureFlagUpdateDocumentShardKeyUsingTransactionApi: description: Feature flag to enable usage of the transaction api for update findAndModify and update commands that change a document's shard key. 
cpp_varname: gFeatureFlagUpdateDocumentShardKeyUsingTransactionApi diff --git a/src/mongo/db/ops/write_ops_exec.cpp b/src/mongo/db/ops/write_ops_exec.cpp index 6fd7e2200c7..f7fc2a84efd 100644 --- a/src/mongo/db/ops/write_ops_exec.cpp +++ b/src/mongo/db/ops/write_ops_exec.cpp @@ -602,11 +602,36 @@ SingleWriteResult makeWriteResultForInsertOrDeleteRetry() { return res; } + +// Returns the flags that determine the type of document validation we want to +// perform. First item in the tuple determines whether to bypass document validation altogether, +// second item determines if _safeContent_ array can be modified in an encrypted collection. +std::tuple<bool, bool> getDocumentValidationFlags(OperationContext* opCtx, + const write_ops::WriteCommandRequestBase& req) { + auto& encryptionInfo = req.getEncryptionInformation(); + const bool fleCrudProcessed = getFleCrudProcessed(opCtx, encryptionInfo); + return std::make_tuple(req.getBypassDocumentValidation(), fleCrudProcessed); +} } // namespace +bool getFleCrudProcessed(OperationContext* opCtx, + const boost::optional<EncryptionInformation>& encryptionInfo) { + if (encryptionInfo && encryptionInfo->getCrudProcessed().value_or(false)) { + uassert(6666201, + "External users cannot have crudProcessed enabled", + AuthorizationSession::get(opCtx->getClient()) + ->isAuthorizedForActionsOnResource(ResourcePattern::forClusterResource(), + ActionType::internal)); + + return true; + } + return false; +} + WriteResult performInserts(OperationContext* opCtx, const write_ops::InsertCommandRequest& wholeOp, OperationSource source) { + // Insert performs its own retries, so we should only be within a WriteUnitOfWork when run in a // transaction. auto txnParticipant = TransactionParticipant::get(opCtx); @@ -641,8 +666,15 @@ WriteResult performInserts(OperationContext* opCtx, uassertStatusOK(userAllowedWriteNS(opCtx, wholeOp.getNamespace())); } - DisableDocumentSchemaValidationIfTrue docSchemaValidationDisabler( - opCtx, wholeOp.getWriteCommandRequestBase().getBypassDocumentValidation()); + const auto [disableDocumentValidation, fleCrudProcessed] = + getDocumentValidationFlags(opCtx, wholeOp.getWriteCommandRequestBase()); + + DisableDocumentSchemaValidationIfTrue docSchemaValidationDisabler(opCtx, + disableDocumentValidation); + + DisableSafeContentValidationIfTrue safeContentValidationDisabler( + opCtx, disableDocumentValidation, fleCrudProcessed); + LastOpFixer lastOpFixer(opCtx, wholeOp.getNamespace()); WriteResult out; @@ -1000,8 +1032,15 @@ WriteResult performUpdates(OperationContext* opCtx, (txnParticipant && opCtx->inMultiDocumentTransaction())); uassertStatusOK(userAllowedWriteNS(opCtx, ns)); - DisableDocumentSchemaValidationIfTrue docSchemaValidationDisabler( - opCtx, wholeOp.getWriteCommandRequestBase().getBypassDocumentValidation()); + const auto [disableDocumentValidation, fleCrudProcessed] = + getDocumentValidationFlags(opCtx, wholeOp.getWriteCommandRequestBase()); + + DisableDocumentSchemaValidationIfTrue docSchemaValidationDisabler(opCtx, + disableDocumentValidation); + + DisableSafeContentValidationIfTrue safeContentValidationDisabler( + opCtx, disableDocumentValidation, fleCrudProcessed); + LastOpFixer lastOpFixer(opCtx, ns); bool containsRetry = false; @@ -1227,8 +1266,15 @@ WriteResult performDeletes(OperationContext* opCtx, (txnParticipant && opCtx->inMultiDocumentTransaction())); uassertStatusOK(userAllowedWriteNS(opCtx, ns)); - DisableDocumentSchemaValidationIfTrue docSchemaValidationDisabler( - opCtx, 
wholeOp.getWriteCommandRequestBase().getBypassDocumentValidation()); + const auto [disableDocumentValidation, fleCrudProcessed] = + getDocumentValidationFlags(opCtx, wholeOp.getWriteCommandRequestBase()); + + DisableDocumentSchemaValidationIfTrue docSchemaValidationDisabler(opCtx, + disableDocumentValidation); + + DisableSafeContentValidationIfTrue safeContentValidationDisabler( + opCtx, disableDocumentValidation, fleCrudProcessed); + LastOpFixer lastOpFixer(opCtx, ns); bool containsRetry = false; diff --git a/src/mongo/db/ops/write_ops_exec.h b/src/mongo/db/ops/write_ops_exec.h index 548a3034713..3550a51c1ce 100644 --- a/src/mongo/db/ops/write_ops_exec.h +++ b/src/mongo/db/ops/write_ops_exec.h @@ -64,6 +64,9 @@ struct WriteResult { bool canContinue = true; }; +bool getFleCrudProcessed(OperationContext* opCtx, + const boost::optional<EncryptionInformation>& encryptionInfo); + /** * Performs a batch of inserts, updates, or deletes. * diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript index 6b14f635095..16e39fc4832 100644 --- a/src/mongo/db/pipeline/SConscript +++ b/src/mongo/db/pipeline/SConscript @@ -99,6 +99,7 @@ env.Library( 'expression_context.cpp', 'expression_function.cpp', 'expression_js_emit.cpp', + 'expression_parser.idl', 'expression_test_api_version.cpp', 'expression_trigonometric.cpp', 'javascript_execution.cpp', @@ -106,6 +107,7 @@ env.Library( 'variables.cpp', ], LIBDEPS=[ + '$BUILD_DIR/mongo/crypto/fle_crypto', '$BUILD_DIR/mongo/db/bson/dotted_path_support', '$BUILD_DIR/mongo/db/commands/test_commands_enabled', '$BUILD_DIR/mongo/db/exec/document_value/document_value', @@ -128,6 +130,7 @@ env.Library( LIBDEPS_PRIVATE=[ '$BUILD_DIR/mongo/db/mongohasher', '$BUILD_DIR/mongo/db/vector_clock', + '$BUILD_DIR/mongo/idl/idl_parser', ], ) diff --git a/src/mongo/db/pipeline/abt/agg_expression_visitor.cpp b/src/mongo/db/pipeline/abt/agg_expression_visitor.cpp index 3d0c8453a47..55b44e4d37b 100644 --- a/src/mongo/db/pipeline/abt/agg_expression_visitor.cpp +++ b/src/mongo/db/pipeline/abt/agg_expression_visitor.cpp @@ -302,6 +302,10 @@ public: unsupportedExpression(expr->getOpName()); } + void visit(const ExpressionInternalFLEEqual* expr) override final { + unsupportedExpression(expr->getOpName()); + } + void visit(const ExpressionMap* expr) override final { unsupportedExpression("$map"); } diff --git a/src/mongo/db/pipeline/dependencies.cpp b/src/mongo/db/pipeline/dependencies.cpp index 8b60a31637c..d2a5563c7c7 100644 --- a/src/mongo/db/pipeline/dependencies.cpp +++ b/src/mongo/db/pipeline/dependencies.cpp @@ -37,6 +37,13 @@ namespace mongo { +std::list<std::string> DepsTracker::sortedFields() const { + // Use a special comparator to put parent fieldpaths before their children. + std::list<std::string> sortedFields(fields.begin(), fields.end()); + sortedFields.sort(PathPrefixComparator()); + return sortedFields; +} + BSONObj DepsTracker::toProjectionWithoutMetadata( TruncateToRootLevel truncationBehavior /*= TruncateToRootLevel::no*/) const { BSONObjBuilder bb; @@ -52,17 +59,21 @@ BSONObj DepsTracker::toProjectionWithoutMetadata( return bb.obj(); } + // Go through dependency fieldpaths to find the minimal set of projections that cover the + // dependencies. For example, the dependencies ["a.b", "a.b.c.g", "c", "c.d", "f"] would be + // minimally covered by the projection {"a.b": 1, "c": 1, "f": 1}. The key operation here is + // folding dependencies into ancestor dependencies, wherever possible. 
This is assisted by a + // special sort in DepsTracker::sortedFields that treats '.' as the first char and thus places + // parent paths directly before their children. bool idSpecified = false; std::string last; - for (const auto& field : fields) { + for (const auto& field : sortedFields()) { if (str::startsWith(field, "_id") && (field.size() == 3 || field[3] == '.')) { idSpecified = true; } if (!last.empty() && str::startsWith(field, last)) { - // we are including a parent of *it so we don't need to include this field - // explicitly. This logic relies on on set iterators going in lexicographic order so - // that a string is always directly before of all fields it prefixes. + // We are including a parent of this field, so we can skip this field. continue; } @@ -96,4 +107,36 @@ void DepsTracker::setNeedsMetadata(DocumentMetadataFields::MetaType type, bool r invariant(required || !_metadataDeps[type]); _metadataDeps[type] = required; } + +// Returns true if the lhs value should sort before the rhs, false otherwise. +bool PathPrefixComparator::operator()(const std::string& lhs, const std::string& rhs) const { + constexpr char dot = '.'; + + for (size_t pos = 0, len = std::min(lhs.size(), rhs.size()); pos < len; ++pos) { + // Below, we explicitly choose unsigned char because the usual const char& returned by + // operator[] is actually signed on x86 and will incorrectly order unicode characters. + unsigned char lchar = lhs[pos], rchar = rhs[pos]; + if (lchar == rchar) { + continue; + } + + // Consider the path delimiter '.' as being less than all other characters, so that + // paths sort directly before any paths they prefix and directly after any paths + // which prefix them. + if (lchar == dot) { + return true; + } else if (rchar == dot) { + return false; + } + + // Otherwise, default to normal character comparison. + return lchar < rchar; + } + + // If we get here, then we have reached the end of lhs and/or rhs and all of their path + // segments up to this point match. If lhs is shorter than rhs, then lhs prefixes rhs + // and should sort before it. + return lhs.size() < rhs.size(); +} + } // namespace mongo diff --git a/src/mongo/db/pipeline/dependencies.h b/src/mongo/db/pipeline/dependencies.h index bda3bf9b243..3c892de8181 100644 --- a/src/mongo/db/pipeline/dependencies.h +++ b/src/mongo/db/pipeline/dependencies.h @@ -184,6 +184,11 @@ struct DepsTracker { } } + /** + * Return fieldpaths ordered such that a parent is immediately before its children. + */ + std::list<std::string> sortedFields() const; + std::set<std::string> fields; // Names of needed fields in dotted notation. std::set<Variables::Id> vars; // IDs of referenced variables. bool needWholeDocument = false; // If true, ignore 'fields'; the whole document is needed. @@ -201,4 +206,13 @@ private: // dependency analysis. QueryMetadataBitSet _metadataDeps; }; + + +/** Custom comparator that orders fieldpath strings by path prefix first, then by field. + * This ensures that a parent field is ordered directly before its children. + */ +struct PathPrefixComparator { + /* Returns true if the lhs value should sort before the rhs, false otherwise. 
*/ + bool operator()(const std::string& lhs, const std::string& rhs) const; +}; } // namespace mongo diff --git a/src/mongo/db/pipeline/dependencies_test.cpp b/src/mongo/db/pipeline/dependencies_test.cpp index f366ad3ce1d..938130b91bd 100644 --- a/src/mongo/db/pipeline/dependencies_test.cpp +++ b/src/mongo/db/pipeline/dependencies_test.cpp @@ -162,6 +162,13 @@ TEST(DependenciesToProjectionTest, ShouldIncludeFieldEvenIfSuffixOfAnotherFieldW BSON("a" << 1 << "ab" << 1 << "_id" << 0)); } +TEST(DependenciesToProjectionTest, ExcludeIndirectDescendants) { + const char* array[] = {"a.b", "_id", "a.b.c.d.e"}; + DepsTracker deps; + deps.fields = arrayToSet(array); + ASSERT_BSONOBJ_EQ(deps.toProjectionWithoutMetadata(), BSON("_id" << 1 << "a.b" << 1)); +} + TEST(DependenciesToProjectionTest, ShouldIncludeIdIfNeeded) { const char* array[] = {"a", "_id"}; DepsTracker deps; @@ -199,6 +206,27 @@ TEST(DependenciesToProjectionTest, ShouldIncludeFieldPrefixedByIdWhenIdSubfieldI BSON("_id.a" << 1 << "_id_a" << 1 << "a" << 1)); } +// SERVER-66418 +TEST(DependenciesToProjectionTest, ChildCoveredByParentWithSpecialChars) { + // without "_id" + { + // This is an important test case because '-' is one of the few chars before '.' in utf-8. + const char* array[] = {"a", "a-b", "a.b"}; + DepsTracker deps; + deps.fields = arrayToSet(array); + ASSERT_BSONOBJ_EQ(deps.toProjectionWithoutMetadata(), + BSON("a" << 1 << "a-b" << 1 << "_id" << 0)); + } + // with "_id" + { + const char* array[] = {"_id", "a", "a-b", "a.b"}; + DepsTracker deps; + deps.fields = arrayToSet(array); + ASSERT_BSONOBJ_EQ(deps.toProjectionWithoutMetadata(), + BSON("_id" << 1 << "a" << 1 << "a-b" << 1)); + } +} + TEST(DependenciesToProjectionTest, ShouldOutputEmptyObjectIfEntireDocumentNeeded) { const char* array[] = {"a"}; // fields ignored with needWholeDocument DepsTracker deps; @@ -259,5 +287,56 @@ TEST(DependenciesToProjectionTest, ASSERT_TRUE(deps.metadataDeps()[DocumentMetadataFields::kTextScore]); } +TEST(DependenciesToProjectionTest, SortFieldPaths) { + const char* array[] = {"", + "A", + "_id", + "a", + "a.b", + "a.b.c", + "a.c", + // '-' char in utf-8 comes before '.' but our special fieldpath sort + // puts '.' first so that children directly follow their parents. 
+ "a-b", + "a-b.ear", + "a-bear", + "a-bear.", + "a🌲", + "b", + "b.a" + "b.aa" + "b.🌲d"}; + DepsTracker deps; + deps.fields = arrayToSet(array); + // our custom sort will restore the ordering above + std::list<std::string> fieldPathSorted = deps.sortedFields(); + auto itr = fieldPathSorted.begin(); + for (unsigned long i = 0; i < fieldPathSorted.size(); i++) { + ASSERT_EQ(*itr, array[i]); + ++itr; + } +} + +TEST(DependenciesToProjectionTest, PathLessThan) { + auto lessThan = PathPrefixComparator(); + ASSERT_FALSE(lessThan("a", "a")); + ASSERT_TRUE(lessThan("a", "aa")); + ASSERT_TRUE(lessThan("a", "b")); + ASSERT_TRUE(lessThan("", "a")); + ASSERT_TRUE(lessThan("Aa", "aa")); + ASSERT_TRUE(lessThan("a.b", "ab")); + ASSERT_TRUE(lessThan("a.b", "a-b")); // SERVER-66418 + ASSERT_TRUE(lessThan("a.b", "a b")); // SERVER-66418 + // verify the difference from the standard sort + ASSERT_TRUE(std::string("a.b") > std::string("a-b")); + ASSERT_TRUE(std::string("a.b") > std::string("a b")); + // test unicode behavior + ASSERT_TRUE(lessThan("a.b", "a🌲")); + ASSERT_TRUE(lessThan("a.b", "a🌲b")); + ASSERT_TRUE(lessThan("🌲", "🌳")); // U+1F332 < U+1F333 + ASSERT_TRUE(lessThan("🌲", "🌲.b")); + ASSERT_FALSE(lessThan("🌲.b", "🌲")); +} + } // namespace } // namespace mongo diff --git a/src/mongo/db/pipeline/document_source_densify.cpp b/src/mongo/db/pipeline/document_source_densify.cpp index 0325dbaa775..740ef193431 100644 --- a/src/mongo/db/pipeline/document_source_densify.cpp +++ b/src/mongo/db/pipeline/document_source_densify.cpp @@ -56,6 +56,10 @@ RangeStatement RangeStatement::parse(RangeSpec spec) { optional<TimeUnit> unit = [&]() { if (auto unit = spec.getUnit()) { + uassert(6586400, + "The step parameter in a range statement must be a whole number when " + "densifying a date range", + step.integral64Bit()); return optional<TimeUnit>(parseTimeUnit(unit.get())); } else { return optional<TimeUnit>(boost::none); @@ -275,8 +279,8 @@ DocumentSourceInternalDensify::DocGenerator::DocGenerator(DensifyValue min, // Extra checks for date step + unit. 
tassert(5733501, "Unit must be specified with a date step", _range.getUnit()); tassert(5733505, - "Step must be representable as an integer for date densification", - _range.getStep().integral()); + "Step must be a whole number for date densification", + _range.getStep().integral64Bit()); } else { tassert(5733506, "Unit must not be specified with non-date values", !_range.getUnit()); } @@ -877,7 +881,7 @@ DensifyValue DensifyValue::increment(const RangeStatement& range) const { }, [&](Date_t date) { return DensifyValue(dateAdd( - date, range.getUnit().value(), range.getStep().getDouble(), timezone())); + date, range.getUnit().value(), range.getStep().coerceToLong(), timezone())); }}, _value); } @@ -891,7 +895,7 @@ DensifyValue DensifyValue::decrement(const RangeStatement& range) const { }, [&](Date_t date) { return DensifyValue(dateAdd( - date, range.getUnit().value(), -range.getStep().getDouble(), timezone())); + date, range.getUnit().value(), -range.getStep().coerceToLong(), timezone())); }}, _value); } @@ -906,7 +910,7 @@ bool DensifyValue::isOnStepRelativeTo(DensifyValue base, RangeStatement range) c }, [&](Date_t date) { auto unit = range.getUnit().value(); - double step = range.getStep().getDouble(); + long long step = range.getStep().coerceToLong(); auto baseDate = base.getDate(); // Months, quarters and years have variable lengths depending on leap days diff --git a/src/mongo/db/pipeline/document_source_densify_test.cpp b/src/mongo/db/pipeline/document_source_densify_test.cpp index 79b16303eb9..c1a0218e9eb 100644 --- a/src/mongo/db/pipeline/document_source_densify_test.cpp +++ b/src/mongo/db/pipeline/document_source_densify_test.cpp @@ -374,7 +374,7 @@ DEATH_TEST(DensifyGeneratorTest, DateMinMustBeLessThanMax, "lower or equal to") 5733502); } -DEATH_TEST(DensifyGeneratorTest, DateStepMustBeInt, "integer") { +DEATH_TEST(DensifyGeneratorTest, DateStepMustBeInt, "whole number") { size_t counter = 0; ASSERT_THROWS_CODE(GenClass(makeDate("2021-01-01T00:00:00.000Z"), RangeStatement(Value(1.5), diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index cc3728c024f..bf0915deddd 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -39,6 +39,9 @@ #include <utility> #include <vector> +#include "mongo/bson/bsonmisc.h" +#include "mongo/bson/bsontypes.h" +#include "mongo/crypto/fle_crypto.h" #include "mongo/db/bson/dotted_path_support.h" #include "mongo/db/commands/feature_compatibility_version_documentation.h" #include "mongo/db/exec/document_value/document.h" @@ -46,6 +49,7 @@ #include "mongo/db/hasher.h" #include "mongo/db/jsobj.h" #include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_parser_gen.h" #include "mongo/db/pipeline/variable_validation.h" #include "mongo/db/query/datetime/date_time_support.h" #include "mongo/db/query/sort_pattern.h" @@ -3742,6 +3746,123 @@ const char* ExpressionLog10::getOpName() const { return "$log10"; } +/* ----------------------- ExpressionInternalFLEEqual ---------------------------- */ +constexpr auto kInternalFleEq = "$_internalFleEq"_sd; + +ExpressionInternalFLEEqual::ExpressionInternalFLEEqual(ExpressionContext* const expCtx, + boost::intrusive_ptr<Expression> field, + ConstDataRange serverToken, + int64_t contentionFactor, + ConstDataRange edcToken) + : Expression(expCtx, {std::move(field)}), + _serverToken(PrfBlockfromCDR(serverToken)), + _edcToken(PrfBlockfromCDR(edcToken)), + _contentionFactor(contentionFactor) { + expCtx->sbeCompatible = 
false; + + auto tokens = + EDCServerCollection::generateEDCTokens(ConstDataRange(_edcToken), _contentionFactor); + + for (auto& token : tokens) { + _cachedEDCTokens.insert(std::move(token.data)); + } +} + +void ExpressionInternalFLEEqual::_doAddDependencies(DepsTracker* deps) const { + for (auto&& operand : _children) { + operand->addDependencies(deps); + } +} + +REGISTER_EXPRESSION_WITH_MIN_VERSION(_internalFleEq, + ExpressionInternalFLEEqual::parse, + AllowedWithApiStrict::kAlways, + AllowedWithClientType::kAny, + multiversion::FeatureCompatibilityVersion::kVersion_6_0); + +intrusive_ptr<Expression> ExpressionInternalFLEEqual::parse(ExpressionContext* const expCtx, + BSONElement expr, + const VariablesParseState& vps) { + + IDLParserErrorContext ctx(kInternalFleEq); + auto fleEq = InternalFleEqStruct::parse(ctx, expr.Obj()); + + auto fieldExpr = Expression::parseOperand(expCtx, fleEq.getField().getElement(), vps); + + auto serverTokenPair = fromEncryptedConstDataRange(fleEq.getServerEncryptionToken()); + + uassert(6672405, + "Invalid server token", + serverTokenPair.first == EncryptedBinDataType::kFLE2TransientRaw && + serverTokenPair.second.length() == sizeof(PrfBlock)); + + auto edcTokenPair = fromEncryptedConstDataRange(fleEq.getEdcDerivedToken()); + + uassert(6672406, + "Invalid edc token", + edcTokenPair.first == EncryptedBinDataType::kFLE2TransientRaw && + edcTokenPair.second.length() == sizeof(PrfBlock)); + + + auto cf = fleEq.getMaxCounter(); + uassert(6672408, "Contention factor must be between 0 and 10000", cf >= 0 && cf < 10000); + + return new ExpressionInternalFLEEqual(expCtx, + std::move(fieldExpr), + serverTokenPair.second, + fleEq.getMaxCounter(), + edcTokenPair.second); +} + +Value toValue(const std::array<std::uint8_t, 32>& buf) { + auto vec = toEncryptedVector(EncryptedBinDataType::kFLE2TransientRaw, buf); + return Value(BSONBinData(vec.data(), vec.size(), BinDataType::Encrypt)); +} + +Value ExpressionInternalFLEEqual::serialize(bool explain) const { + return Value(Document{{kInternalFleEq, + Document{{"field", _children[0]->serialize(explain)}, + {"edc", toValue(_edcToken)}, + {"counter", Value(static_cast<long long>(_contentionFactor))}, + {"server", toValue(_serverToken)}}}}); +} + +Value ExpressionInternalFLEEqual::evaluate(const Document& root, Variables* variables) const { + // Inputs + // 1. Value for FLE2IndexedEqualityEncryptedValue field + + Value fieldValue = _children[0]->evaluate(root, variables); + + if (fieldValue.nullish()) { + return Value(BSONNULL); + } + + if (fieldValue.getType() != BinData) { + return Value(false); + } + + auto fieldValuePair = fromEncryptedBinData(fieldValue); + + uassert(6672407, + "Invalid encrypted indexed field", + fieldValuePair.first == EncryptedBinDataType::kFLE2EqualityIndexedValue); + + // Value matches if + // 1. Decrypt field is successful + // 2. 
EDC_u Token is in GenTokens(EDC Token, ContentionFactor) + // + auto swIndexed = + EDCServerCollection::decryptAndParse(ConstDataRange(_serverToken), fieldValuePair.second); + uassertStatusOK(swIndexed); + auto indexed = swIndexed.getValue(); + + return Value(_cachedEDCTokens.count(indexed.edc.data) == 1); +} + +const char* ExpressionInternalFLEEqual::getOpName() const { + return kInternalFleEq.rawData(); +} + /* ------------------------ ExpressionNary ----------------------------- */ /** diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h index fa59e89ecf2..9cce6d0b1e2 100644 --- a/src/mongo/db/pipeline/expression.h +++ b/src/mongo/db/pipeline/expression.h @@ -29,6 +29,7 @@ #pragma once +#include "mongo/base/data_range.h" #include "mongo/platform/basic.h" #include <algorithm> @@ -2197,6 +2198,38 @@ public: } }; +class ExpressionInternalFLEEqual final : public Expression { +public: + ExpressionInternalFLEEqual(ExpressionContext* expCtx, + boost::intrusive_ptr<Expression> field, + ConstDataRange serverToken, + int64_t contentionFactor, + ConstDataRange edcToken); + Value serialize(bool explain) const final; + + Value evaluate(const Document& root, Variables* variables) const final; + const char* getOpName() const; + + static boost::intrusive_ptr<Expression> parse(ExpressionContext* expCtx, + BSONElement expr, + const VariablesParseState& vps); + void _doAddDependencies(DepsTracker* deps) const final; + + void acceptVisitor(ExpressionMutableVisitor* visitor) final { + return visitor->visit(this); + } + + void acceptVisitor(ExpressionConstVisitor* visitor) const final { + return visitor->visit(this); + } + +private: + std::array<std::uint8_t, 32> _serverToken; + std::array<std::uint8_t, 32> _edcToken; + int64_t _contentionFactor; + stdx::unordered_set<std::array<std::uint8_t, 32>> _cachedEDCTokens; +}; + class ExpressionMap final : public Expression { public: ExpressionMap( diff --git a/src/mongo/db/pipeline/expression_parser.idl b/src/mongo/db/pipeline/expression_parser.idl new file mode 100644 index 00000000000..9f1cde70856 --- /dev/null +++ b/src/mongo/db/pipeline/expression_parser.idl @@ -0,0 +1,57 @@ +# Copyright (C) 2022-present MongoDB, Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the Server Side Public License, version 1, +# as published by MongoDB, Inc. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# Server Side Public License for more details. +# +# You should have received a copy of the Server Side Public License +# along with this program. If not, see +# <http://www.mongodb.com/licensing/server-side-public-license>. +# +# As a special exception, the copyright holders give permission to link the +# code of portions of this program with the OpenSSL library under certain +# conditions as described in each individual source file and distribute +# linked combinations including the program with the OpenSSL library. You +# must comply with the Server Side Public License in all respects for +# all of the code used other than as permitted herein. If you modify file(s) +# with this exception, you may extend this exception to your version of the +# file(s), but you are not obligated to do so. If you do not wish to do so, +# delete this exception statement from your version. 
If you delete this +# exception statement from all source files in the program, then also delete +# it in the license file. + +global: + cpp_namespace: "mongo" + +imports: + - "mongo/idl/basic_types.idl" + +structs: + + InternalFleEqStruct: + description: "Struct for $_internalFleEq" + strict: true + fields: + field: + description: "Expression" + type: IDLAnyType + cpp_name: field + edc: + description: "EDCDerivedFromDataToken" + type: bindata_encrypt + cpp_name: edcDerivedToken + server: + description: "ServerDataEncryptionLevel1Token" + type: bindata_encrypt + cpp_name: serverEncryptionToken + counter: + description: "Queryable Encryption max counter" + type: long + cpp_name: maxCounter + + diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp index 63eb5cc884c..fd6f1c3490e 100644 --- a/src/mongo/db/pipeline/expression_test.cpp +++ b/src/mongo/db/pipeline/expression_test.cpp @@ -169,6 +169,7 @@ void parseAndVerifyResults( ASSERT_VALUE_EQ(expr->evaluate({}, &expCtx.variables), expected); } + /* ------------------------- ExpressionArrayToObject -------------------------- */ TEST(ExpressionArrayToObjectTest, KVFormatSimple) { @@ -3715,4 +3716,240 @@ TEST(ExpressionCondTest, ConstantCondShouldOptimizeWithNonConstantBranches) { ASSERT_BSONOBJ_BINARY_EQ(expectedResult, expressionToBson(optimizedExprCond)); } +TEST(ExpressionFLETest, BadInputs) { + + auto expCtx = ExpressionContextForTest(); + auto vps = expCtx.variablesParseState; + { + auto expr = fromjson("{$_internalFleEq: 12}"); + ASSERT_THROWS_CODE(ExpressionInternalFLEEqual::parse(&expCtx, expr.firstElement(), vps), + DBException, + 10065); + } +} + +// Test we return true if it matches +TEST(ExpressionFLETest, TestBinData) { + auto expCtx = ExpressionContextForTest(); + auto vps = expCtx.variablesParseState; + + { + auto expr = fromjson(R"({$_internalFleEq: { + field: { + "$binary": { + "base64": + "BxI0VngSNJh2EjQSNFZ4kBIQ0JE8aMUFkPk5sSTVqfdNNfjqUfQQ1Uoj0BBcthrWoe9wyU3cN6zmWaQBPJ97t0ZPbecnMsU736yXre6cBO4Zdt/wThtY+v5+7vFgNnWpgRP0e+vam6QPmLvbBrO0LdsvAPTGW4yqwnzCIXCoEg7QPGfbfAXKPDTNenBfRlawiblmTOhO/6ljKotWsMp22q/rpHrn9IEIeJmecwuuPIJ7EA+XYQ3hOKVccYf2ogoK73+8xD/Vul83Qvr84Q8afc4QUMVs8A==", + "subType": "6" + } + }, + server: { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + }, + counter: { + "$numberLong": "3" + }, + edc: { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + } } })"); + auto exprFle = ExpressionInternalFLEEqual::parse(&expCtx, expr.firstElement(), vps); + + ASSERT_VALUE_EQ(exprFle->evaluate({}, &expCtx.variables), Value(true)); + } + + // Negative: Use wrong server token + { + auto expr = fromjson(R"({$_internalFleEq: { + field: { + "$binary": { + "base64": + "BxI0VngSNJh2EjQSNFZ4kBIQ0JE8aMUFkPk5sSTVqfdNNfjqUfQQ1Uoj0BBcthrWoe9wyU3cN6zmWaQBPJ97t0ZPbecnMsU736yXre6cBO4Zdt/wThtY+v5+7vFgNnWpgRP0e+vam6QPmLvbBrO0LdsvAPTGW4yqwnzCIXCoEg7QPGfbfAXKPDTNenBfRlawiblmTOhO/6ljKotWsMp22q/rpHrn9IEIeJmecwuuPIJ7EA+XYQ3hOKVccYf2ogoK73+8xD/Vul83Qvr84Q8afc4QUMVs8A==", + "subType": "6" + } + }, + server: { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + }, + counter: { + "$numberLong": "3" + }, + edc: { + "$binary": { + "base64": "CEWSMQID7SFWYAUI3ZKSFKATKRYDQFNXXEOGAD5D4RSG", + "subType": "6" + } + } } })"); + auto exprFle = ExpressionInternalFLEEqual::parse(&expCtx, expr.firstElement(), vps); + + ASSERT_VALUE_EQ(exprFle->evaluate({}, &expCtx.variables), 
Value(false)); + } + + // Negative: Use wrong edc token + { + auto expr = fromjson(R"({$_internalFleEq: { + field: { + "$binary": { + "base64": + "BxI0VngSNJh2EjQSNFZ4kBIQ0JE8aMUFkPk5sSTVqfdNNfjqUfQQ1Uoj0BBcthrWoe9wyU3cN6zmWaQBPJ97t0ZPbecnMsU736yXre6cBO4Zdt/wThtY+v5+7vFgNnWpgRP0e+vam6QPmLvbBrO0LdsvAPTGW4yqwnzCIXCoEg7QPGfbfAXKPDTNenBfRlawiblmTOhO/6ljKotWsMp22q/rpHrn9IEIeJmecwuuPIJ7EA+XYQ3hOKVccYf2ogoK73+8xD/Vul83Qvr84Q8afc4QUMVs8A==", + "subType": "6" + } + }, + server: { + "$binary": { + "base64": "COUAC/ERLYAKKX6B0VZ1R3QODOQFFJQJD+XLGIPU4/PS", + "subType": "6" + } + }, + counter: { + "$numberLong": "3" + }, + edc: { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + } } })"); + auto exprFle = ExpressionInternalFLEEqual::parse(&expCtx, expr.firstElement(), vps); + + ASSERT_THROWS_CODE( + exprFle->evaluate({}, &expCtx.variables), DBException, ErrorCodes::Overflow); + } +} + +TEST(ExpressionFLETest, TestBinData_ContentionFactor) { + auto expCtx = ExpressionContextForTest(); + auto vps = expCtx.variablesParseState; + + // Use the wrong contention factor - 0 + { + auto expr = fromjson(R"({$_internalFleEq: { + field: { + "$binary": { + "base64": + "BxI0VngSNJh2EjQSNFZ4kBIQ5+Wa5+SZafJeRUDGdLNx+i2ADDkyV2qA90Xcve7FqltoDm1PllSSgUS4fYtw3XDjzoNZrFFg8LfG2wH0HYbLMswv681KJpmEw7+RXy4CcPVFgoRFt24N13p7jT+pqu2oQAHAoxYTy/TsiAyY4RnAMiXYGg3hWz4AO/WxHNSyq6B6kX5d7x/hrXvppsZDc2Pmhd+c5xmovlv5RPj7wnNld13kYcMluztjNswiCH05hM/kp2/P7kw30iVnbz0SZxn1FjjCug==", + "subType": "6" + } + }, + server: { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + }, + counter: { + "$numberLong": "0" + }, + edc: { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + } } })"); + auto exprFle = ExpressionInternalFLEEqual::parse(&expCtx, expr.firstElement(), vps); + + ASSERT_VALUE_EQ(exprFle->evaluate({}, &expCtx.variables), Value(false)); + } + + // Use the right contention factor - 50 + { + auto expr = fromjson(R"({$_internalFleEq: { + field: { + "$binary": { + "base64": +"BxI0VngSNJh2EjQSNFZ4kBIQ5+Wa5+SZafJeRUDGdLNx+i2ADDkyV2qA90Xcve7FqltoDm1PllSSgUS4fYtw3XDjzoNZrFFg8LfG2wH0HYbLMswv681KJpmEw7+RXy4CcPVFgoRFt24N13p7jT+pqu2oQAHAoxYTy/TsiAyY4RnAMiXYGg3hWz4AO/WxHNSyq6B6kX5d7x/hrXvppsZDc2Pmhd+c5xmovlv5RPj7wnNld13kYcMluztjNswiCH05hM/kp2/P7kw30iVnbz0SZxn1FjjCug==", + "subType": "6" + } + }, + server: { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + }, + counter: { + "$numberLong": "50" + }, + edc: { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + } } })"); + auto exprFle = ExpressionInternalFLEEqual::parse(&expCtx, expr.firstElement(), vps); + + ASSERT_VALUE_EQ(exprFle->evaluate({}, &expCtx.variables), Value(true)); + } +} + +TEST(ExpressionFLETest, TestBinData_RoundTrip) { + auto expCtx = ExpressionContextForTest(); + auto vps = expCtx.variablesParseState; + + auto expr = fromjson(R"({$_internalFleEq: { + field: { + "$binary": { + "base64": + "BxI0VngSNJh2EjQSNFZ4kBIQ0JE8aMUFkPk5sSTVqfdNNfjqUfQQ1Uoj0BBcthrWoe9wyU3cN6zmWaQBPJ97t0ZPbecnMsU736yXre6cBO4Zdt/wThtY+v5+7vFgNnWpgRP0e+vam6QPmLvbBrO0LdsvAPTGW4yqwnzCIXCoEg7QPGfbfAXKPDTNenBfRlawiblmTOhO/6ljKotWsMp22q/rpHrn9IEIeJmecwuuPIJ7EA+XYQ3hOKVccYf2ogoK73+8xD/Vul83Qvr84Q8afc4QUMVs8A==", + "subType": "6" + } + }, + server: { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + }, + counter: { + "$numberLong": "3" + }, + 
edc: { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + } } })"); + auto exprFle = ExpressionInternalFLEEqual::parse(&expCtx, expr.firstElement(), vps); + + ASSERT_VALUE_EQ(exprFle->evaluate({}, &expCtx.variables), Value(true)); + + // Verify it round trips + auto value = exprFle->serialize(false); + + auto roundTripExpr = fromjson(R"({$_internalFleEq: { + field: { + "$const" : { "$binary": { + "base64": + "BxI0VngSNJh2EjQSNFZ4kBIQ0JE8aMUFkPk5sSTVqfdNNfjqUfQQ1Uoj0BBcthrWoe9wyU3cN6zmWaQBPJ97t0ZPbecnMsU736yXre6cBO4Zdt/wThtY+v5+7vFgNnWpgRP0e+vam6QPmLvbBrO0LdsvAPTGW4yqwnzCIXCoEg7QPGfbfAXKPDTNenBfRlawiblmTOhO/6ljKotWsMp22q/rpHrn9IEIeJmecwuuPIJ7EA+XYQ3hOKVccYf2ogoK73+8xD/Vul83Qvr84Q8afc4QUMVs8A==", + "subType": "6" + }} + }, + edc: { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + }, + counter: { + "$numberLong": "3" + }, + server: { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + } + } })"); + + + ASSERT_BSONOBJ_EQ(value.getDocument().toBson(), roundTripExpr); +} + } // namespace ExpressionTests diff --git a/src/mongo/db/pipeline/expression_visitor.h b/src/mongo/db/pipeline/expression_visitor.h index 46ad3ee6295..6b7c4fc4cdd 100644 --- a/src/mongo/db/pipeline/expression_visitor.h +++ b/src/mongo/db/pipeline/expression_visitor.h @@ -153,6 +153,7 @@ class ExpressionHyperbolicSine; class ExpressionInternalFindSlice; class ExpressionInternalFindPositional; class ExpressionInternalFindElemMatch; +class ExpressionInternalFLEEqual; class ExpressionInternalJsEmit; class ExpressionFunction; class ExpressionDegreesToRadians; @@ -245,6 +246,7 @@ public: virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionLn>) = 0; virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionLog>) = 0; virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionLog10>) = 0; + virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionInternalFLEEqual>) = 0; virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionMap>) = 0; virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionMeta>) = 0; virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionMod>) = 0; @@ -424,6 +426,7 @@ struct SelectiveConstExpressionVisitorBase : public ExpressionConstVisitor { void visit(const ExpressionLn*) override {} void visit(const ExpressionLog*) override {} void visit(const ExpressionLog10*) override {} + void visit(const ExpressionInternalFLEEqual*) override {} void visit(const ExpressionMap*) override {} void visit(const ExpressionMeta*) override {} void visit(const ExpressionMod*) override {} diff --git a/src/mongo/db/pipeline/pipeline_d.cpp b/src/mongo/db/pipeline/pipeline_d.cpp index 85c3e82839a..b21e8635c51 100644 --- a/src/mongo/db/pipeline/pipeline_d.cpp +++ b/src/mongo/db/pipeline/pipeline_d.cpp @@ -1451,7 +1451,9 @@ PipelineD::buildInnerQueryExecutorGeneric(const MultipleCollectionAccessor& coll // This produces {$const: maxBucketSpanSeconds} make_intrusive<ExpressionConstant>( expCtx.get(), - Value{unpack->getBucketMaxSpanSeconds() * 1000}))), + Value{static_cast<long long>( + unpack->getBucketMaxSpanSeconds()) * + 1000}))), expCtx); pipeline->_sources.insert( unpackIter, diff --git a/src/mongo/db/query/fle/server_rewrite.cpp b/src/mongo/db/query/fle/server_rewrite.cpp index 185a979ec3d..1660c368093 100644 --- a/src/mongo/db/query/fle/server_rewrite.cpp +++ 
b/src/mongo/db/query/fle/server_rewrite.cpp @@ -27,11 +27,13 @@ * it in the license file. */ +#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kQuery #include "mongo/db/query/fle/server_rewrite.h" #include <memory> +#include "mongo/bson/bsonmisc.h" #include "mongo/bson/bsonobj.h" #include "mongo/bson/bsonobjbuilder.h" #include "mongo/bson/bsontypes.h" @@ -48,9 +50,11 @@ #include "mongo/db/pipeline/expression.h" #include "mongo/db/query/collation/collator_factory_interface.h" #include "mongo/db/service_context.h" +#include "mongo/logv2/log.h" #include "mongo/s/grid.h" #include "mongo/s/transaction_router_resource_yielder.h" #include "mongo/util/assert_util.h" +#include "mongo/util/intrusive_counter.h" namespace mongo::fle { @@ -68,6 +72,56 @@ std::unique_ptr<CollatorInterface> collatorFromBSON(OperationContext* opCtx, } namespace { +template <typename PayloadT> +boost::intrusive_ptr<ExpressionInternalFLEEqual> generateFleEqualMatch(StringData path, + const PayloadT& ffp, + ExpressionContext* expCtx) { + // Generate { $_internalFleEq: { field: "$field_name", server: f_3, counter: cm, edc: k_EDC] } + auto tokens = ParsedFindPayload(ffp); + + uassert(6672401, + "Missing required field server encryption token in find payload", + tokens.serverToken.has_value()); + + return make_intrusive<ExpressionInternalFLEEqual>( + expCtx, + ExpressionFieldPath::createPathFromString( + expCtx, path.toString(), expCtx->variablesParseState), + tokens.serverToken.get().data, + tokens.maxCounter.value_or(0LL), + tokens.edcToken.data); +} + + +template <typename PayloadT> +std::unique_ptr<ExpressionInternalFLEEqual> generateFleEqualMatchUnique(StringData path, + const PayloadT& ffp, + ExpressionContext* expCtx) { + // Generate { $_internalFleEq: { field: "$field_name", server: f_3, counter: cm, edc: k_EDC] } + auto tokens = ParsedFindPayload(ffp); + + uassert(6672419, + "Missing required field server encryption token in find payload", + tokens.serverToken.has_value()); + + return std::make_unique<ExpressionInternalFLEEqual>( + expCtx, + ExpressionFieldPath::createPathFromString( + expCtx, path.toString(), expCtx->variablesParseState), + tokens.serverToken.get().data, + tokens.maxCounter.value_or(0LL), + tokens.edcToken.data); +} + +std::unique_ptr<MatchExpression> generateFleEqualMatchAndExpr(StringData path, + const BSONElement ffp, + ExpressionContext* expCtx) { + auto fleEqualMatch = generateFleEqualMatch(path, ffp, expCtx); + + return std::make_unique<ExprMatchExpression>(fleEqualMatch, expCtx); +} + + /** * This section defines a mapping from DocumentSources to the dispatch function to appropriately * handle FLE rewriting for that stage. This should be kept in line with code on the client-side @@ -128,7 +182,8 @@ public: * The final output will look like * {$or: [{$in: [tag0, "$__safeContent__"]}, {$in: [tag1, "$__safeContent__"]}, ...]}. */ - std::unique_ptr<Expression> rewriteComparisonsToEncryptedField( + std::unique_ptr<Expression> rewriteInToEncryptedField( + const Expression* leftExpr, const std::vector<boost::intrusive_ptr<Expression>>& equalitiesList) { size_t numFFPs = 0; std::vector<boost::intrusive_ptr<Expression>> orListElems; @@ -140,11 +195,122 @@ public: continue; } - // ... rewrite the payload to a list of tags... numFFPs++; + } + } + + // Finally, construct an $or of all of the $ins. 
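For orientation in the hunk above: rewriteInToEncryptedField() first counts how many of the $in equalities are encrypted find payloads and requires that all or none of them be encrypted before rewriting; each payload is then expanded into tags, and every tag contributes one {$in: [tag, "$__safeContent__"]} predicate to a final $or. A rough stand-alone sketch of that invariant and of the rewritten shape, using plain strings in place of the real Expression tree; InElement and rewriteEncryptedIn are made-up names for illustration only.

    #include <stdexcept>
    #include <string>
    #include <vector>

    // Illustrative stand-in for one element of the $in list: either a plaintext value
    // or an encrypted find payload that expands into a set of tags.
    struct InElement {
        bool isEncryptedPayload;
        std::vector<std::string> tags;
    };

    // Each tag becomes {$in: [<tag>, "$__safeContent__"]} and the per-tag predicates
    // are OR'ed together; mixing encrypted and plaintext elements is rejected.
    std::string rewriteEncryptedIn(const std::vector<InElement>& elements) {
        size_t numPayloads = 0;
        for (const auto& e : elements) {
            if (e.isEncryptedPayload) {
                ++numPayloads;
            }
        }
        if (numPayloads == 0) {
            return "";  // nothing encrypted, leave the $in alone
        }
        if (numPayloads != elements.size()) {
            throw std::invalid_argument("if any $in elements are encrypted, all must be");
        }

        std::string out = "{$or: [";
        bool first = true;
        for (const auto& e : elements) {
            for (const auto& tag : e.tags) {
                if (!first) {
                    out += ", ";
                }
                out += "{$in: [" + tag + ", \"$__safeContent__\"]}";
                first = false;
            }
        }
        return out + "]}";
    }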
+ if (numFFPs == 0) { + return nullptr; + } + + uassert( + 6334102, + "If any elements in an comparison expression are encrypted, then all elements should " + "be encrypted.", + numFFPs == equalitiesList.size()); + + auto leftFieldPath = dynamic_cast<const ExpressionFieldPath*>(leftExpr); + uassert(6672417, + "$in is only supported with Queryable Encryption when the first argument is a " + "field path", + leftFieldPath != nullptr); + + if (!queryRewriter->isForceHighCardinality()) { + try { + for (auto& equality : equalitiesList) { + // For each expression representing a FleFindPayload... + if (auto constChild = dynamic_cast<ExpressionConstant*>(equality.get())) { + // ... rewrite the payload to a list of tags... + auto tags = queryRewriter->rewritePayloadAsTags(constChild->getValue()); + for (auto&& tagElt : tags) { + // ... and for each tag, construct expression {$in: [tag, + // "$__safeContent__"]}. + std::vector<boost::intrusive_ptr<Expression>> inVec{ + ExpressionConstant::create(queryRewriter->expCtx(), tagElt), + ExpressionFieldPath::createPathFromString( + queryRewriter->expCtx(), + kSafeContent, + queryRewriter->expCtx()->variablesParseState)}; + orListElems.push_back(make_intrusive<ExpressionIn>( + queryRewriter->expCtx(), std::move(inVec))); + } + } + } + + didRewrite = true; + + return std::make_unique<ExpressionOr>(queryRewriter->expCtx(), + std::move(orListElems)); + } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) { + LOGV2_DEBUG(6672403, + 2, + "FLE Max tag limit hit during aggregation $in rewrite", + "__error__"_attr = ex.what()); + + if (queryRewriter->getHighCardinalityMode() != + FLEQueryRewriter::HighCardinalityMode::kUseIfNeeded) { + throw; + } + + // fall through + } + } + + for (auto& equality : equalitiesList) { + if (auto constChild = dynamic_cast<ExpressionConstant*>(equality.get())) { + auto fleEqExpr = generateFleEqualMatch( + leftFieldPath->getFieldPathWithoutCurrentPrefix().fullPath(), + constChild->getValue(), + queryRewriter->expCtx()); + orListElems.push_back(fleEqExpr); + } + } + + didRewrite = true; + return std::make_unique<ExpressionOr>(queryRewriter->expCtx(), std::move(orListElems)); + } + + // Rewrite a [$eq : [$fieldpath, constant]] or [$eq: [constant, $fieldpath]] + // to _internalFleEq: {field: $fieldpath, edc: edcToken, counter: N, server: serverToken} + std::unique_ptr<Expression> rewriteComparisonsToEncryptedField( + const std::vector<boost::intrusive_ptr<Expression>>& equalitiesList) { + + auto leftConstant = dynamic_cast<ExpressionConstant*>(equalitiesList[0].get()); + auto rightConstant = dynamic_cast<ExpressionConstant*>(equalitiesList[1].get()); + + bool isLeftFFP = leftConstant && queryRewriter->isFleFindPayload(leftConstant->getValue()); + bool isRightFFP = + rightConstant && queryRewriter->isFleFindPayload(rightConstant->getValue()); + + uassert(6334100, + "Cannot compare two encrypted constants to each other", + !(isLeftFFP && isRightFFP)); + + // No FLE Find Payload + if (!isLeftFFP && !isRightFFP) { + return nullptr; + } + + auto leftFieldPath = dynamic_cast<ExpressionFieldPath*>(equalitiesList[0].get()); + auto rightFieldPath = dynamic_cast<ExpressionFieldPath*>(equalitiesList[1].get()); + + uassert( + 6672413, + "Queryable Encryption only supports comparisons between a field path and a constant", + leftFieldPath || rightFieldPath); + + auto fieldPath = leftFieldPath ? leftFieldPath : rightFieldPath; + auto constChild = isLeftFFP ? 
leftConstant : rightConstant; + + if (!queryRewriter->isForceHighCardinality()) { + try { + std::vector<boost::intrusive_ptr<Expression>> orListElems; + auto tags = queryRewriter->rewritePayloadAsTags(constChild->getValue()); for (auto&& tagElt : tags) { - // ... and for each tag, construct expression {$in: [tag, "$__safeContent__"]}. + // ... and for each tag, construct expression {$in: [tag, + // "$__safeContent__"]}. std::vector<boost::intrusive_ptr<Expression>> inVec{ ExpressionConstant::create(queryRewriter->expCtx(), tagElt), ExpressionFieldPath::createPathFromString( @@ -154,21 +320,33 @@ public: orListElems.push_back( make_intrusive<ExpressionIn>(queryRewriter->expCtx(), std::move(inVec))); } + + didRewrite = true; + return std::make_unique<ExpressionOr>(queryRewriter->expCtx(), + std::move(orListElems)); + + } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) { + LOGV2_DEBUG(6672409, + 2, + "FLE Max tag limit hit during query $in rewrite", + "__error__"_attr = ex.what()); + + if (queryRewriter->getHighCardinalityMode() != + FLEQueryRewriter::HighCardinalityMode::kUseIfNeeded) { + throw; + } + + // fall through } } - // Finally, construct an $or of all of the $ins. - if (numFFPs == 0) { - return nullptr; - } - uassert( - 6334102, - "If any elements in an comparison expression are encrypted, then all elements should " - "be encrypted.", - numFFPs == equalitiesList.size()); + auto fleEqExpr = + generateFleEqualMatchUnique(fieldPath->getFieldPathWithoutCurrentPrefix().fullPath(), + constChild->getValue(), + queryRewriter->expCtx()); didRewrite = true; - return std::make_unique<ExpressionOr>(queryRewriter->expCtx(), std::move(orListElems)); + return fleEqExpr; } std::unique_ptr<Expression> postVisit(Expression* exp) { @@ -177,30 +355,28 @@ public: // ignored when rewrites are done; there is no extra information in that child that // doesn't exist in the FFPs in the $in list. if (auto inList = dynamic_cast<ExpressionArray*>(inExpr->getOperandList()[1].get())) { - return rewriteComparisonsToEncryptedField(inList->getChildren()); + return rewriteInToEncryptedField(inExpr->getOperandList()[0].get(), + inList->getChildren()); } } else if (auto eqExpr = dynamic_cast<ExpressionCompare*>(exp); eqExpr && (eqExpr->getOp() == ExpressionCompare::EQ || eqExpr->getOp() == ExpressionCompare::NE)) { // Rewrite an $eq comparing an encrypted field and an encrypted constant to an $or. - // Either child may be the constant, so try rewriting both. - auto or0 = rewriteComparisonsToEncryptedField({eqExpr->getChildren()[0]}); - auto or1 = rewriteComparisonsToEncryptedField({eqExpr->getChildren()[1]}); - uassert(6334100, "Cannot compare two encrypted constants to each other", !or0 || !or1); + auto newExpr = rewriteComparisonsToEncryptedField(eqExpr->getChildren()); // Neither child is an encrypted constant, and no rewriting needs to be done. - if (!or0 && !or1) { + if (!newExpr) { return nullptr; } // Exactly one child was an encrypted constant. The other child can be ignored; there is // no extra information in that child that doesn't exist in the FFP. if (eqExpr->getOp() == ExpressionCompare::NE) { - std::vector<boost::intrusive_ptr<Expression>> notChild{(or0 ? or0 : or1).release()}; + std::vector<boost::intrusive_ptr<Expression>> notChild{newExpr.release()}; return std::make_unique<ExpressionNot>(queryRewriter->expCtx(), std::move(notChild)); } - return std::move(or0 ? 
or0 : or1); + return newExpr; } return nullptr; @@ -213,11 +389,14 @@ public: BSONObj rewriteEncryptedFilter(const FLEStateCollectionReader& escReader, const FLEStateCollectionReader& eccReader, boost::intrusive_ptr<ExpressionContext> expCtx, - BSONObj filter) { + BSONObj filter, + HighCardinalityModeAllowed mode) { + if (auto rewritten = - FLEQueryRewriter(expCtx, escReader, eccReader).rewriteMatchExpression(filter)) { + FLEQueryRewriter(expCtx, escReader, eccReader, mode).rewriteMatchExpression(filter)) { return rewritten.get(); } + return filter; } @@ -273,16 +452,18 @@ public: FilterRewrite(boost::intrusive_ptr<ExpressionContext> expCtx, const NamespaceString& nss, const EncryptionInformation& encryptInfo, - const BSONObj toRewrite) - : RewriteBase(expCtx, nss, encryptInfo), userFilter(toRewrite) {} + const BSONObj toRewrite, + HighCardinalityModeAllowed mode) + : RewriteBase(expCtx, nss, encryptInfo), userFilter(toRewrite), _mode(mode) {} ~FilterRewrite(){}; void doRewrite(FLEStateCollectionReader& escReader, FLEStateCollectionReader& eccReader) final { - rewrittenFilter = rewriteEncryptedFilter(escReader, eccReader, expCtx, userFilter); + rewrittenFilter = rewriteEncryptedFilter(escReader, eccReader, expCtx, userFilter, _mode); } const BSONObj userFilter; BSONObj rewrittenFilter; + HighCardinalityModeAllowed _mode; }; // This helper executes the rewrite(s) inside a transaction. The transaction runs in a separate @@ -324,7 +505,8 @@ BSONObj rewriteEncryptedFilterInsideTxn(FLEQueryInterface* queryImpl, StringData db, const EncryptedFieldConfig& efc, boost::intrusive_ptr<ExpressionContext> expCtx, - BSONObj filter) { + BSONObj filter, + HighCardinalityModeAllowed mode) { auto makeCollectionReader = [&](FLEQueryInterface* queryImpl, const StringData& coll) { NamespaceString nss(db, coll); auto docCount = queryImpl->countDocuments(nss); @@ -332,7 +514,8 @@ BSONObj rewriteEncryptedFilterInsideTxn(FLEQueryInterface* queryImpl, }; auto escReader = makeCollectionReader(queryImpl, efc.getEscCollection().get()); auto eccReader = makeCollectionReader(queryImpl, efc.getEccCollection().get()); - return rewriteEncryptedFilter(escReader, eccReader, expCtx, filter); + + return rewriteEncryptedFilter(escReader, eccReader, expCtx, filter, mode); } BSONObj rewriteQuery(OperationContext* opCtx, @@ -340,8 +523,9 @@ BSONObj rewriteQuery(OperationContext* opCtx, const NamespaceString& nss, const EncryptionInformation& info, BSONObj filter, - GetTxnCallback getTransaction) { - auto sharedBlock = std::make_shared<FilterRewrite>(expCtx, nss, info, filter); + GetTxnCallback getTransaction, + HighCardinalityModeAllowed mode) { + auto sharedBlock = std::make_shared<FilterRewrite>(expCtx, nss, info, filter, mode); doFLERewriteInTxn(opCtx, sharedBlock, getTransaction); return sharedBlock->rewrittenFilter.getOwned(); } @@ -365,7 +549,8 @@ void processFindCommand(OperationContext* opCtx, nss, findCommand->getEncryptionInformation().get(), findCommand->getFilter().getOwned(), - getTransaction)); + getTransaction, + HighCardinalityModeAllowed::kAllow)); // The presence of encryptionInformation is a signal that this is a FLE request that requires // special processing. 
Once we've rewritten the query, it's no longer a "special" FLE query, but // a normal query that can be executed by the query system like any other, so remove @@ -389,7 +574,8 @@ void processCountCommand(OperationContext* opCtx, nss, countCommand->getEncryptionInformation().get(), countCommand->getQuery().getOwned(), - getTxn)); + getTxn, + HighCardinalityModeAllowed::kAllow)); // The presence of encryptionInformation is a signal that this is a FLE request that requires // special processing. Once we've rewritten the query, it's no longer a "special" FLE query, but // a normal query that can be executed by the query system like any other, so remove @@ -503,59 +689,112 @@ std::vector<Value> FLEQueryRewriter::rewritePayloadAsTags(Value fleFindPayload) return tagVec; } -std::unique_ptr<InMatchExpression> FLEQueryRewriter::rewriteEq( - const EqualityMatchExpression* expr) { + +std::unique_ptr<MatchExpression> FLEQueryRewriter::rewriteEq(const EqualityMatchExpression* expr) { auto ffp = expr->getData(); if (!isFleFindPayload(ffp)) { return nullptr; } - auto obj = rewritePayloadAsTags(ffp); - - auto tags = std::vector<BSONElement>(); - obj.elems(tags); + if (_mode != HighCardinalityMode::kForceAlways) { + try { + auto obj = rewritePayloadAsTags(ffp); + + auto tags = std::vector<BSONElement>(); + obj.elems(tags); + + auto inExpr = std::make_unique<InMatchExpression>(kSafeContent); + inExpr->setBackingBSON(std::move(obj)); + auto status = inExpr->setEqualities(std::move(tags)); + uassertStatusOK(status); + _rewroteLastExpression = true; + return inExpr; + } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) { + LOGV2_DEBUG(6672410, + 2, + "FLE Max tag limit hit during query $eq rewrite", + "__error__"_attr = ex.what()); + + if (_mode != HighCardinalityMode::kUseIfNeeded) { + throw; + } - auto inExpr = std::make_unique<InMatchExpression>(kSafeContent); - inExpr->setBackingBSON(std::move(obj)); - auto status = inExpr->setEqualities(std::move(tags)); - uassertStatusOK(status); + // fall through + } + } + auto exprMatch = generateFleEqualMatchAndExpr(expr->path(), ffp, _expCtx.get()); _rewroteLastExpression = true; - return inExpr; + return exprMatch; } -std::unique_ptr<InMatchExpression> FLEQueryRewriter::rewriteIn(const InMatchExpression* expr) { - auto backingBSONBuilder = BSONArrayBuilder(); +std::unique_ptr<MatchExpression> FLEQueryRewriter::rewriteIn(const InMatchExpression* expr) { size_t numFFPs = 0; for (auto& eq : expr->getEqualities()) { if (isFleFindPayload(eq)) { - auto obj = rewritePayloadAsTags(eq); ++numFFPs; - for (auto&& elt : obj) { - backingBSONBuilder.append(elt); - } } } + if (numFFPs == 0) { return nullptr; } + // All elements in an encrypted $in expression should be FFPs. 
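The rewriteEq() body above is the clearest statement of the new fallback strategy: unless the rewriter is forced into the runtime form, it first attempts the tag expansion and, when FLEMaxTagLimitExceeded is thrown, either rethrows (for modes that cannot accept $expr) or falls through to a $_internalFleEq predicate. Here is a compact sketch of that control flow with stand-in types; TagLimitExceeded, expandToTags and the two build helpers are hypothetical placeholders, not the server's API.

    #include <stdexcept>
    #include <string>
    #include <vector>

    // Hypothetical placeholders: expandToTags() throws when the tag set would exceed the
    // memory limit, and the two build helpers only label which strategy was chosen.
    struct TagLimitExceeded : std::runtime_error {
        using std::runtime_error::runtime_error;
    };

    enum class Mode { kForceAlways, kUseIfNeeded, kDisallow };

    std::vector<std::string> expandToTags(int tagCount, int limit) {
        if (tagCount > limit) {
            throw TagLimitExceeded("tag limit exceeded");
        }
        return std::vector<std::string>(tagCount, "tag");
    }

    std::string buildSafeContentIn(const std::vector<std::string>& tags) {
        return "$in over __safeContent__ with " + std::to_string(tags.size()) + " tags";
    }

    std::string buildRuntimePredicate() {
        return "$expr {$_internalFleEq: ...}";
    }

    // Mirrors the control flow of rewriteEq(): prefer the tag-based index lookup and
    // fall back to the runtime comparison only when the mode allows it.
    std::string rewriteEquality(Mode mode, int tagCount, int limit) {
        if (mode != Mode::kForceAlways) {
            try {
                return buildSafeContentIn(expandToTags(tagCount, limit));
            } catch (const TagLimitExceeded&) {
                if (mode != Mode::kUseIfNeeded) {
                    throw;  // e.g. contexts that cannot accept an $expr rewrite
                }
                // fall through to the runtime predicate
            }
        }
        return buildRuntimePredicate();
    }

rewriteIn(), which continues below, follows the same try/expand/fall-back shape; its fallback simply emits one runtime predicate per element and wraps them all in an $or.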
uassert( 6329400, "If any elements in a $in expression are encrypted, then all elements should be encrypted.", numFFPs == expr->getEqualities().size()); - auto backingBSON = backingBSONBuilder.arr(); - auto allTags = std::vector<BSONElement>(); - backingBSON.elems(allTags); + if (_mode != HighCardinalityMode::kForceAlways) { + + try { + auto backingBSONBuilder = BSONArrayBuilder(); + + for (auto& eq : expr->getEqualities()) { + auto obj = rewritePayloadAsTags(eq); + for (auto&& elt : obj) { + backingBSONBuilder.append(elt); + } + } + + auto backingBSON = backingBSONBuilder.arr(); + auto allTags = std::vector<BSONElement>(); + backingBSON.elems(allTags); + + auto inExpr = std::make_unique<InMatchExpression>(kSafeContent); + inExpr->setBackingBSON(std::move(backingBSON)); + auto status = inExpr->setEqualities(std::move(allTags)); + uassertStatusOK(status); + + _rewroteLastExpression = true; + return inExpr; + + } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) { + LOGV2_DEBUG(6672411, + 2, + "FLE Max tag limit hit during query $in rewrite", + "__error__"_attr = ex.what()); + + if (_mode != HighCardinalityMode::kUseIfNeeded) { + throw; + } + + // fall through + } + } + + std::vector<std::unique_ptr<MatchExpression>> matches; + matches.reserve(numFFPs); - auto inExpr = std::make_unique<InMatchExpression>(kSafeContent); - inExpr->setBackingBSON(std::move(backingBSON)); - auto status = inExpr->setEqualities(std::move(allTags)); - uassertStatusOK(status); + for (auto& eq : expr->getEqualities()) { + auto exprMatch = generateFleEqualMatchAndExpr(expr->path(), eq, _expCtx.get()); + matches.push_back(std::move(exprMatch)); + } + auto orExpr = std::make_unique<OrMatchExpression>(std::move(matches)); _rewroteLastExpression = true; - return inExpr; + return orExpr; } } // namespace mongo::fle diff --git a/src/mongo/db/query/fle/server_rewrite.h b/src/mongo/db/query/fle/server_rewrite.h index c71d1f01392..bf02eeebd4e 100644 --- a/src/mongo/db/query/fle/server_rewrite.h +++ b/src/mongo/db/query/fle/server_rewrite.h @@ -31,7 +31,7 @@ #include <memory> -#include "boost/smart_ptr/intrusive_ptr.hpp" +#include <boost/smart_ptr/intrusive_ptr.hpp> #include "mongo/bson/bsonobj.h" #include "mongo/crypto/fle_crypto.h" @@ -47,6 +47,14 @@ class FLEQueryInterface; namespace fle { /** + * Low Selectivity rewrites use $expr which is not supported in all commands such as upserts. + */ +enum class HighCardinalityModeAllowed { + kAllow, + kDisallow, +}; + +/** * Make a collator object from its BSON representation. Useful when creating ExpressionContext * objects for parsing MatchExpressions as part of the server-side rewrite. */ @@ -62,7 +70,8 @@ BSONObj rewriteQuery(OperationContext* opCtx, const NamespaceString& nss, const EncryptionInformation& info, BSONObj filter, - GetTxnCallback getTransaction); + GetTxnCallback getTransaction, + HighCardinalityModeAllowed mode); /** * Process a find command with encryptionInformation in-place, rewriting the filter condition so @@ -100,11 +109,13 @@ std::unique_ptr<Pipeline, PipelineDeleter> processPipeline( * from inside an existing transaction using a FLEQueryInterface constructed from a * transaction client. 
 */
-BSONObj rewriteEncryptedFilterInsideTxn(FLEQueryInterface* queryImpl,
-                                        StringData db,
-                                        const EncryptedFieldConfig& efc,
-                                        boost::intrusive_ptr<ExpressionContext> expCtx,
-                                        BSONObj filter);
+BSONObj rewriteEncryptedFilterInsideTxn(
+    FLEQueryInterface* queryImpl,
+    StringData db,
+    const EncryptedFieldConfig& efc,
+    boost::intrusive_ptr<ExpressionContext> expCtx,
+    BSONObj filter,
+    HighCardinalityModeAllowed mode = HighCardinalityModeAllowed::kDisallow);
 
 /**
  * Class which handles rewriting filter MatchExpressions for FLE2. The functionality is encapsulated
@@ -116,14 +127,37 @@ BSONObj rewriteEncryptedFilterInsideTxn(FLEQueryInterface* queryImpl,
  */
 class FLEQueryRewriter {
 public:
+    enum class HighCardinalityMode {
+        // Always use high cardinality filters, used by tests
+        kForceAlways,
+
+        // Use high cardinality mode if $in rewrites do not fit in the
+        // internalQueryFLERewriteMemoryLimit memory limit
+        kUseIfNeeded,
+
+        // Do not rewrite into high cardinality filter, throw exceptions instead
+        // Some contexts like upsert do not support $expr
+        kDisallow,
+    };
+
     /**
      * Takes in references to collection readers for the ESC and ECC that are used during tag
      * computation.
      */
     FLEQueryRewriter(boost::intrusive_ptr<ExpressionContext> expCtx,
                      const FLEStateCollectionReader& escReader,
-                     const FLEStateCollectionReader& eccReader)
+                     const FLEStateCollectionReader& eccReader,
+                     HighCardinalityModeAllowed mode = HighCardinalityModeAllowed::kAllow)
         : _expCtx(expCtx), _escReader(&escReader), _eccReader(&eccReader) {
+
+        if (internalQueryFLEAlwaysUseHighCardinalityMode.load()) {
+            _mode = HighCardinalityMode::kForceAlways;
+        }
+
+        if (mode == HighCardinalityModeAllowed::kDisallow) {
+            _mode = HighCardinalityMode::kDisallow;
+        }
+
         // This isn't the "real" query so we don't want to increment Expression
         // counters here.
         _expCtx->stopExpressionCounters();
@@ -184,6 +218,18 @@ public:
         return _expCtx.get();
     }
 
+    bool isForceHighCardinality() const {
+        return _mode == HighCardinalityMode::kForceAlways;
+    }
+
+    void setForceHighCardinalityForTest() {
+        _mode = HighCardinalityMode::kForceAlways;
+    }
+
+    HighCardinalityMode getHighCardinalityMode() const {
+        return _mode;
+    }
+
 protected:
     // This constructor should only be used for mocks in testing.
     FLEQueryRewriter(boost::intrusive_ptr<ExpressionContext> expCtx)
@@ -196,8 +242,8 @@ private:
     std::unique_ptr<MatchExpression> _rewrite(MatchExpression* me);
 
     virtual BSONObj rewritePayloadAsTags(BSONElement fleFindPayload) const;
-    std::unique_ptr<InMatchExpression> rewriteEq(const EqualityMatchExpression* expr);
-    std::unique_ptr<InMatchExpression> rewriteIn(const InMatchExpression* expr);
+    std::unique_ptr<MatchExpression> rewriteEq(const EqualityMatchExpression* expr);
+    std::unique_ptr<MatchExpression> rewriteIn(const InMatchExpression* expr);
 
     boost::intrusive_ptr<ExpressionContext> _expCtx;
 
@@ -207,7 +253,10 @@ private:
     const FLEStateCollectionReader* _eccReader;
 
     // True if the last Expression or MatchExpression processed by this rewriter was rewritten.
-    bool _rewroteLastExpression;
+    bool _rewroteLastExpression = false;
+
+    // Controls how query rewriter rewrites the query
+    HighCardinalityMode _mode{HighCardinalityMode::kUseIfNeeded};
 };
 
diff --git a/src/mongo/db/query/fle/server_rewrite_test.cpp b/src/mongo/db/query/fle/server_rewrite_test.cpp
index cb81656dcb6..034de8f0aa9 100644
--- a/src/mongo/db/query/fle/server_rewrite_test.cpp
+++ b/src/mongo/db/query/fle/server_rewrite_test.cpp
@@ -31,7 +31,9 @@
 #include <memory>
 
 #include "mongo/bson/bsonelement.h"
+#include "mongo/bson/bsonmisc.h"
 #include "mongo/bson/bsonobjbuilder.h"
+#include "mongo/bson/bsontypes.h"
 #include "mongo/db/matcher/expression_leaf.h"
 #include "mongo/db/pipeline/expression_context_for_test.h"
 #include "mongo/db/query/fle/server_rewrite.h"
@@ -42,9 +44,19 @@
 namespace mongo {
 namespace {
 
-class MockFLEQueryRewriter : public fle::FLEQueryRewriter {
+class BasicMockFLEQueryRewriter : public fle::FLEQueryRewriter {
 public:
-    MockFLEQueryRewriter() : fle::FLEQueryRewriter(new ExpressionContextForTest()), _tags() {}
+    BasicMockFLEQueryRewriter() : fle::FLEQueryRewriter(new ExpressionContextForTest()) {}
+
+    BSONObj rewriteMatchExpressionForTest(const BSONObj& obj) {
+        auto res = rewriteMatchExpression(obj);
+        return res ? res.get() : obj;
+    }
+};
+
+class MockFLEQueryRewriter : public BasicMockFLEQueryRewriter {
+public:
+    MockFLEQueryRewriter() : _tags() {}
 
     bool isFleFindPayload(const BSONElement& fleFindPayload) const override {
         return _encryptedFields.find(fleFindPayload.fieldNameStringData()) !=
@@ -56,11 +68,6 @@ public:
         _tags[fieldvalue] = tags;
     }
 
-    BSONObj rewriteMatchExpressionForTest(const BSONObj& obj) {
-        auto res = rewriteMatchExpression(obj);
-        return res ? res.get() : obj;
-    }
-
 private:
     BSONObj rewritePayloadAsTags(BSONElement fleFindPayload) const override {
         ASSERT(fleFindPayload.isNumber());  // Only accept numbers as mock FFPs.
@@ -72,6 +79,7 @@ private: std::map<std::pair<StringData, int>, BSONObj> _tags; std::set<StringData> _encryptedFields; }; + class FLEServerRewriteTest : public unittest::Test { public: FLEServerRewriteTest() {} @@ -361,5 +369,290 @@ TEST_F(FLEServerRewriteTest, ComparisonToObjectIgnored) { } } +template <typename T> +std::vector<uint8_t> toEncryptedVector(EncryptedBinDataType dt, T t) { + BSONObj obj = t.toBSON(); + + std::vector<uint8_t> buf(obj.objsize() + 1); + buf[0] = static_cast<uint8_t>(dt); + + std::copy(obj.objdata(), obj.objdata() + obj.objsize(), buf.data() + 1); + + return buf; +} + +template <typename T> +void toEncryptedBinData(StringData field, EncryptedBinDataType dt, T t, BSONObjBuilder* builder) { + auto buf = toEncryptedVector(dt, t); + + builder->appendBinData(field, buf.size(), BinDataType::Encrypt, buf.data()); +} + +constexpr auto kIndexKeyId = "12345678-1234-9876-1234-123456789012"_sd; +constexpr auto kUserKeyId = "ABCDEFAB-1234-9876-1234-123456789012"_sd; +static UUID indexKeyId = uassertStatusOK(UUID::parse(kIndexKeyId.toString())); +static UUID userKeyId = uassertStatusOK(UUID::parse(kUserKeyId.toString())); + +std::vector<char> testValue = {0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19}; +std::vector<char> testValue2 = {0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29}; + +const FLEIndexKey& getIndexKey() { + static std::string indexVec = hexblob::decode( + "7dbfebc619aa68a659f64b8e23ccd21644ac326cb74a26840c3d2420176c40ae088294d00ad6cae9684237b21b754cf503f085c25cd320bf035c3417416e1e6fe3d9219f79586582112740b2add88e1030d91926ae8afc13ee575cfb8bb965b7"_sd); + static FLEIndexKey indexKey(KeyMaterial(indexVec.begin(), indexVec.end())); + return indexKey; +} + +const FLEUserKey& getUserKey() { + static std::string userVec = hexblob::decode( + "a7ddbc4c8be00d51f68d9d8e485f351c8edc8d2206b24d8e0e1816d005fbe520e489125047d647b0d8684bfbdbf09c304085ed086aba6c2b2b1677ccc91ced8847a733bf5e5682c84b3ee7969e4a5fe0e0c21e5e3ee190595a55f83147d8de2a"_sd); + static FLEUserKey userKey(KeyMaterial(userVec.begin(), userVec.end())); + return userKey; +} + + +BSONObj generateFFP(StringData path, int value) { + auto indexKey = getIndexKey(); + FLEIndexKeyAndId indexKeyAndId(indexKey.data, indexKeyId); + auto userKey = getUserKey(); + FLEUserKeyAndId userKeyAndId(userKey.data, indexKeyId); + + BSONObj doc = BSON("value" << value); + auto element = doc.firstElement(); + auto fpp = FLEClientCrypto::serializeFindPayload(indexKeyAndId, userKeyAndId, element, 0); + + BSONObjBuilder builder; + toEncryptedBinData(path, EncryptedBinDataType::kFLE2FindEqualityPayload, fpp, &builder); + return builder.obj(); +} + +class FLEServerHighCardRewriteTest : public unittest::Test { +public: + FLEServerHighCardRewriteTest() {} + + void setUp() override {} + + void tearDown() override {} + +protected: + BasicMockFLEQueryRewriter _mock; +}; + + +TEST_F(FLEServerHighCardRewriteTest, HighCard_TopLevel_Equality) { + _mock.setForceHighCardinalityForTest(); + + auto match = generateFFP("ssn", 1); + auto expected = fromjson(R"({ + "$expr": { + "$_internalFleEq": { + "field": "$ssn", + "edc": { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + }, + "counter": { + "$numberLong": "0" + }, + "server": { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + } + } + } +})"); + + auto actual = _mock.rewriteMatchExpressionForTest(match); + ASSERT_BSONOBJ_EQ(actual, expected); +} + + 
+TEST_F(FLEServerHighCardRewriteTest, HighCard_TopLevel_In) { + _mock.setForceHighCardinalityForTest(); + + auto ffp1 = generateFFP("ssn", 1); + auto ffp2 = generateFFP("ssn", 2); + auto ffp3 = generateFFP("ssn", 3); + auto expected = fromjson(R"({ + "$or": [ + { + "$expr": { + "$_internalFleEq": { + "field": "$ssn", + "edc": { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + }, + "counter": { + "$numberLong": "0" + }, + "server": { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + } + } + } + }, + { + "$expr": { + "$_internalFleEq": { + "field": "$ssn", + "edc": { + "$binary": { + "base64": "CLpCo6rNuYMVT+6n1HCX15MNrVYDNqf6udO46ayo43Sw", + "subType": "6" + } + }, + "counter": { + "$numberLong": "0" + }, + "server": { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + } + } + } + }, + { + "$expr": { + "$_internalFleEq": { + "field": "$ssn", + "edc": { + "$binary": { + "base64": "CPi44oCQHnNDeRqHsNLzbdCeHt2DK/wCly0g2dxU5fqN", + "subType": "6" + } + }, + "counter": { + "$numberLong": "0" + }, + "server": { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + } + } + } + } + ] +})"); + + auto match = + BSON("ssn" << BSON("$in" << BSON_ARRAY(ffp1.firstElement() + << ffp2.firstElement() << ffp3.firstElement()))); + + auto actual = _mock.rewriteMatchExpressionForTest(match); + ASSERT_BSONOBJ_EQ(actual, expected); +} + + +TEST_F(FLEServerHighCardRewriteTest, HighCard_TopLevel_Expr) { + + _mock.setForceHighCardinalityForTest(); + + auto ffp = generateFFP("$ssn", 1); + int len; + auto v = ffp.firstElement().binDataClean(len); + auto match = BSON("$expr" << BSON("$eq" << BSON_ARRAY(ffp.firstElement().fieldName() + << BSONBinData(v, len, Encrypt)))); + + auto expected = fromjson(R"({ "$expr": { + "$_internalFleEq": { + "field": "$ssn", + "edc": { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + }, + "counter": { + "$numberLong": "0" + }, + "server": { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + } + } + } + })"); + + auto actual = _mock.rewriteMatchExpressionForTest(match); + ASSERT_BSONOBJ_EQ(actual, expected); +} + +TEST_F(FLEServerHighCardRewriteTest, HighCard_TopLevel_Expr_In) { + + _mock.setForceHighCardinalityForTest(); + + auto ffp = generateFFP("$ssn", 1); + int len; + auto v = ffp.firstElement().binDataClean(len); + + auto ffp2 = generateFFP("$ssn", 1); + int len2; + auto v2 = ffp2.firstElement().binDataClean(len2); + + auto match = BSON( + "$expr" << BSON("$in" << BSON_ARRAY(ffp.firstElement().fieldName() + << BSON_ARRAY(BSONBinData(v, len, Encrypt) + << BSONBinData(v2, len2, Encrypt))))); + + auto expected = fromjson(R"({ "$expr": { "$or" : [ { + "$_internalFleEq": { + "field": "$ssn", + "edc": { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + }, + "counter": { + "$numberLong": "0" + }, + "server": { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + } + }}, + { + "$_internalFleEq": { + "field": "$ssn", + "edc": { + "$binary": { + "base64": "CEWSmQID7SfwyAUI3ZkSFkATKryDQfnxXEOGad5d4Rsg", + "subType": "6" + } + }, + "counter": { + "$numberLong": "0" + }, + "server": { + "$binary": { + "base64": "COuac/eRLYakKX6B0vZ1r3QodOQFfjqJD+xlGiPu4/Ps", + "subType": "6" + } + } + }} + ]}})"); + + auto actual = 
_mock.rewriteMatchExpressionForTest(match); + ASSERT_BSONOBJ_EQ(actual, expected); +} + } // namespace } // namespace mongo diff --git a/src/mongo/db/query/query_knobs.idl b/src/mongo/db/query/query_knobs.idl index 18851f0ddb9..f894629037f 100644 --- a/src/mongo/db/query/query_knobs.idl +++ b/src/mongo/db/query/query_knobs.idl @@ -863,6 +863,14 @@ server_parameters: gt: 0 lt: 16777216 + internalQueryFLEAlwaysUseHighCardinalityMode: + description: "Boolean flag to force FLE to always use low selectivity mode" + set_at: [ startup, runtime ] + cpp_varname: "internalQueryFLEAlwaysUseHighCardinalityMode" + cpp_vartype: AtomicWord<bool> + default: + expr: false + # Note for adding additional query knobs: # # When adding a new query knob, you should consider whether or not you need to add an 'on_update' diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp index d422d5a0dd7..9fd4aba048e 100644 --- a/src/mongo/db/query/sbe_stage_builder_expression.cpp +++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp @@ -377,6 +377,7 @@ public: void visit(const ExpressionLn* expr) final {} void visit(const ExpressionLog* expr) final {} void visit(const ExpressionLog10* expr) final {} + void visit(const ExpressionInternalFLEEqual* expr) final {} void visit(const ExpressionMap* expr) final {} void visit(const ExpressionMeta* expr) final {} void visit(const ExpressionMod* expr) final {} @@ -609,6 +610,7 @@ public: void visit(const ExpressionLn* expr) final {} void visit(const ExpressionLog* expr) final {} void visit(const ExpressionLog10* expr) final {} + void visit(const ExpressionInternalFLEEqual* expr) final {} void visit(const ExpressionMap* expr) final {} void visit(const ExpressionMeta* expr) final {} void visit(const ExpressionMod* expr) final {} @@ -2316,6 +2318,9 @@ public: _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(log10Expr))); } + void visit(const ExpressionInternalFLEEqual* expr) final { + unsupportedExpression("$_internalFleEq"); + } void visit(const ExpressionMap* expr) final { unsupportedExpression("$map"); } diff --git a/src/mongo/db/s/collection_sharding_state.cpp b/src/mongo/db/s/collection_sharding_state.cpp index b8fd37f965e..03cb0dd10fa 100644 --- a/src/mongo/db/s/collection_sharding_state.cpp +++ b/src/mongo/db/s/collection_sharding_state.cpp @@ -157,10 +157,6 @@ std::vector<NamespaceString> CollectionShardingState::getCollectionNames(Operati return collectionsMap->getCollectionNames(); } -void CollectionShardingState::clearFilteringMetadata_DoNotUseIt(OperationContext* opCtx) { - clearFilteringMetadata(opCtx); -} - void CollectionShardingStateFactory::set(ServiceContext* service, std::unique_ptr<CollectionShardingStateFactory> factory) { auto& collectionsMap = CollectionShardingStateMap::get(service); diff --git a/src/mongo/db/s/collection_sharding_state.h b/src/mongo/db/s/collection_sharding_state.h index e15af4cfb31..575ccb85f2f 100644 --- a/src/mongo/db/s/collection_sharding_state.h +++ b/src/mongo/db/s/collection_sharding_state.h @@ -159,19 +159,6 @@ public: * Returns the number of ranges scheduled for deletion on the collection. */ virtual size_t numberOfRangesScheduledForDeletion() const = 0; - - /** - * Public interface to clear the filtering metadata associated to a namespace. Do not use it - * without the consent of the Sharding Team, please. 
- */ - void clearFilteringMetadata_DoNotUseIt(OperationContext* opCtx); - -private: - /** - * Private interface to clear the filtering metadata. Please, do not make it public. See - * clearFilteringMetadata_DoNotUseIt for more information. - */ - virtual void clearFilteringMetadata(OperationContext* opCtx) = 0; }; /** diff --git a/src/mongo/db/s/collection_sharding_state_factory_standalone.cpp b/src/mongo/db/s/collection_sharding_state_factory_standalone.cpp index 995a5f28c72..6422dc066f1 100644 --- a/src/mongo/db/s/collection_sharding_state_factory_standalone.cpp +++ b/src/mongo/db/s/collection_sharding_state_factory_standalone.cpp @@ -67,9 +67,6 @@ public: size_t numberOfRangesScheduledForDeletion() const override { return 0; } - -private: - void clearFilteringMetadata(OperationContext* opCtx) override {} }; } // namespace diff --git a/src/mongo/db/session_catalog_mongod.cpp b/src/mongo/db/session_catalog_mongod.cpp index fbfa223d80c..50b63c0f390 100644 --- a/src/mongo/db/session_catalog_mongod.cpp +++ b/src/mongo/db/session_catalog_mongod.cpp @@ -379,23 +379,55 @@ void createTransactionTable(OperationContext* opCtx) { auto createCollectionStatus = storageInterface->createCollection( opCtx, NamespaceString::kSessionTransactionsTableNamespace, options); + auto internalTransactionsFlagEnabled = + feature_flags::gFeatureFlagInternalTransactions.isEnabled( + serverGlobalParams.featureCompatibility); + + // This flag is off by default and only exists to facilitate creating the partial index more + // easily, so we don't tie it to FCV. This overrides the internal transactions feature flag. + auto alwaysCreateIndexFlagEnabled = + feature_flags::gFeatureFlagAlwaysCreateConfigTransactionsPartialIndexOnStepUp + .isEnabledAndIgnoreFCV(); + if (createCollectionStatus == ErrorCodes::NamespaceExists) { - if (!feature_flags::gFeatureFlagInternalTransactions.isEnabled( - serverGlobalParams.featureCompatibility)) { + if (!internalTransactionsFlagEnabled && !alwaysCreateIndexFlagEnabled) { return; } - AutoGetCollection autoColl( - opCtx, NamespaceString::kSessionTransactionsTableNamespace, LockMode::MODE_IS); + bool collectionIsEmpty = false; + { + AutoGetCollection autoColl( + opCtx, NamespaceString::kSessionTransactionsTableNamespace, LockMode::MODE_IS); + invariant(autoColl); + + if (autoColl->getIndexCatalog()->findIndexByName( + opCtx, MongoDSessionCatalog::kConfigTxnsPartialIndexName)) { + // Index already exists, so there's nothing to do. + return; + } + + collectionIsEmpty = autoColl->isEmpty(opCtx); + } + + if (!collectionIsEmpty) { + // Unless explicitly enabled, don't create the index to avoid delaying step up. + if (alwaysCreateIndexFlagEnabled) { + AutoGetCollection autoColl( + opCtx, NamespaceString::kSessionTransactionsTableNamespace, LockMode::MODE_X); + IndexBuildsCoordinator::get(opCtx)->createIndex( + opCtx, + autoColl->uuid(), + MongoDSessionCatalog::getConfigTxnPartialIndexSpec(), + IndexBuildsManager::IndexConstraints::kEnforce, + false /* fromMigration */); + } - // During failover recovery it is possible that the collection is created, but the partial - // index is not since they are recorded as separate oplog entries. If it is already created - // or if the collection isn't empty we can return early. - if (autoColl->getIndexCatalog()->findIndexByName( - opCtx, MongoDSessionCatalog::kConfigTxnsPartialIndexName) || - !autoColl->isEmpty(opCtx)) { return; } + + // The index does not exist and the collection is empty, so fall through to create it on the + // empty collection. 
This can happen after a failover because the collection and index + // creation are recorded as separate oplog entries. } else { uassertStatusOKWithContext(createCollectionStatus, str::stream() @@ -404,8 +436,7 @@ void createTransactionTable(OperationContext* opCtx) { << " collection"); } - if (!feature_flags::gFeatureFlagInternalTransactions.isEnabled( - serverGlobalParams.featureCompatibility)) { + if (!internalTransactionsFlagEnabled && !alwaysCreateIndexFlagEnabled) { return; } diff --git a/src/mongo/idl/basic_types.idl b/src/mongo/idl/basic_types.idl index d124f29cdae..634b05d9539 100644 --- a/src/mongo/idl/basic_types.idl +++ b/src/mongo/idl/basic_types.idl @@ -156,6 +156,13 @@ types: cpp_type: "std::array<std::uint8_t, 16>" deserializer: "mongo::BSONElement::uuid" + bindata_encrypt: + bson_serialization_type: bindata + bindata_subtype: encrypt + description: "A BSON bindata of encrypt sub type" + cpp_type: "std::vector<std::uint8_t>" + deserializer: "mongo::BSONElement::_binDataVector" + uuid: bson_serialization_type: bindata bindata_subtype: uuid diff --git a/src/mongo/idl/generic_argument.idl b/src/mongo/idl/generic_argument.idl index fdf340cbef1..5cb308bd2e1 100644 --- a/src/mongo/idl/generic_argument.idl +++ b/src/mongo/idl/generic_argument.idl @@ -104,6 +104,7 @@ generic_argument_lists: mayBypassWriteBlocking: forward_to_shards: true + generic_reply_field_lists: generic_reply_fields_api_v1: description: "Fields that may appear in any command reply. These are guaranteed backwards-compatible for as long as the server supports API Version 1." diff --git a/src/mongo/s/write_ops/batch_write_op.cpp b/src/mongo/s/write_ops/batch_write_op.cpp index b034bcd0ee4..a61ee3dd4bf 100644 --- a/src/mongo/s/write_ops/batch_write_op.cpp +++ b/src/mongo/s/write_ops/batch_write_op.cpp @@ -289,6 +289,13 @@ void populateCollectionUUIDMismatch(OperationContext* opCtx, } } +int getEncryptionInformationSize(const BatchedCommandRequest& req) { + if (!req.getWriteCommandRequestBase().getEncryptionInformation()) { + return 0; + } + return req.getWriteCommandRequestBase().getEncryptionInformation().get().toBSON().objsize(); +} + } // namespace BatchWriteOp::BatchWriteOp(OperationContext* opCtx, const BatchedCommandRequest& clientRequest) @@ -421,6 +428,7 @@ Status BatchWriteOp::targetBatch( // // The constant 4 is chosen as the size of the BSON representation of the stmtId. const int writeSizeBytes = getWriteSizeBytes(writeOp) + + getEncryptionInformationSize(_clientRequest) + write_ops::kWriteCommandBSONArrayPerElementOverheadBytes + (_batchTxnNum ? 
write_ops::kWriteCommandBSONArrayPerElementOverheadBytes + 4 : 0);
 
@@ -583,6 +591,9 @@ BatchedCommandRequest BatchWriteOp::buildBatchRequest(const TargetedWriteBatch&
     wcb.setOrdered(_clientRequest.getWriteCommandRequestBase().getOrdered());
     wcb.setCollectionUUID(_clientRequest.getWriteCommandRequestBase().getCollectionUUID());
 
+    wcb.setEncryptionInformation(
+        _clientRequest.getWriteCommandRequestBase().getEncryptionInformation());
+
     if (targeter.isShardedTimeSeriesBucketsNamespace() &&
         !_clientRequest.getNS().isTimeseriesBucketsCollection()) {
         wcb.setIsTimeseriesNamespace(true);
diff --git a/src/mongo/util/dns_query_posix-impl.h b/src/mongo/util/dns_query_posix-impl.h
index 8c39084deaf..93431114f7f 100644
--- a/src/mongo/util/dns_query_posix-impl.h
+++ b/src/mongo/util/dns_query_posix-impl.h
@@ -188,7 +188,7 @@ public:
             uasserted(ErrorCodes::DNSProtocolError, "DNS CNAME record could not be decompressed");
         }
 
-        return std::string(&buf[0], length);
+        return std::string(&buf[0]);
     }
 
     DNSQueryType getType() const {
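For readers following the server_rewrite changes above, the following is a minimal standalone sketch, not part of this commit, of the fallback shape the rewriter applies: expand encrypted find payloads into __safeContent__ tags and build a $in, but when the tag count exceeds the configured limit (surfaced in the real code as FLEMaxTagLimitExceeded against internalQueryFLERewriteMemoryLimit), fall back to a disjunction of per-payload $_internalFleEq predicates, or throw when the caller context, such as an upsert, cannot use $expr. All names here (FindPayload, kMaxTotalTags, this rewriteIn) are hypothetical stand-ins, and the string output exists only to keep the sketch self-contained; the actual rewriter builds MatchExpression trees (InMatchExpression, OrMatchExpression) instead.

// Illustrative sketch only; it mirrors the shape of the high cardinality
// fallback, not the actual MongoDB implementation or its types.
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

enum class HighCardinalityMode { kForceAlways, kUseIfNeeded, kDisallow };

// Hypothetical stand-in for an encrypted find payload whose tags have
// already been derived from the ESC/ECC state collections.
struct FindPayload {
    std::string fieldPath;
    std::vector<std::string> tags;
};

// Stand-in for the memory limit the server enforces on tag expansion.
constexpr std::size_t kMaxTotalTags = 4;

std::string rewriteIn(const std::vector<FindPayload>& payloads, HighCardinalityMode mode) {
    if (mode != HighCardinalityMode::kForceAlways) {
        std::size_t totalTags = 0;
        for (const auto& p : payloads) {
            totalTags += p.tags.size();
        }
        if (totalTags <= kMaxTotalTags) {
            // Normal path: a single $in over all tags on __safeContent__.
            std::string out = "{__safeContent__: {$in: [";
            for (const auto& p : payloads) {
                for (const auto& t : p.tags) {
                    out += t + ", ";
                }
            }
            out += "]}}";
            return out;
        }
        if (mode == HighCardinalityMode::kDisallow) {
            // Contexts that cannot use $expr (e.g. upserts) have to fail here.
            throw std::runtime_error("tag limit exceeded and $expr rewrite not allowed");
        }
        // Otherwise fall through to the high cardinality form.
    }
    // High cardinality path: one equality predicate per payload instead of
    // materializing every tag.
    std::string out = "{$or: [";
    for (const auto& p : payloads) {
        out += "{$expr: {$_internalFleEq: {field: \"$" + p.fieldPath + "\", ...}}}, ";
    }
    out += "]}";
    return out;
}

int main() {
    std::vector<FindPayload> payloads = {
        {"ssn", {"t1", "t2", "t3"}},
        {"ssn", {"t4", "t5"}},
    };
    // Five tags exceed kMaxTotalTags, so this prints the $or fallback form.
    std::cout << rewriteIn(payloads, HighCardinalityMode::kUseIfNeeded) << "\n";
    return 0;
}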