diff options
Diffstat (limited to 'src/mongo/db/query/fle/server_rewrite.cpp')
-rw-r--r-- | src/mongo/db/query/fle/server_rewrite.cpp | 355 |
1 files changed, 298 insertions, 57 deletions
diff --git a/src/mongo/db/query/fle/server_rewrite.cpp b/src/mongo/db/query/fle/server_rewrite.cpp index f4f02bcb383..2aeb99a4061 100644 --- a/src/mongo/db/query/fle/server_rewrite.cpp +++ b/src/mongo/db/query/fle/server_rewrite.cpp @@ -32,6 +32,7 @@ #include <memory> +#include "mongo/bson/bsonmisc.h" #include "mongo/bson/bsonobj.h" #include "mongo/bson/bsonobjbuilder.h" #include "mongo/bson/bsontypes.h" @@ -48,9 +49,14 @@ #include "mongo/db/pipeline/expression.h" #include "mongo/db/query/collation/collator_factory_interface.h" #include "mongo/db/service_context.h" +#include "mongo/logv2/log.h" #include "mongo/s/grid.h" #include "mongo/s/transaction_router_resource_yielder.h" #include "mongo/util/assert_util.h" +#include "mongo/util/intrusive_counter.h" + + +#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kQuery namespace mongo::fle { @@ -68,6 +74,56 @@ std::unique_ptr<CollatorInterface> collatorFromBSON(OperationContext* opCtx, } namespace { +template <typename PayloadT> +boost::intrusive_ptr<ExpressionInternalFLEEqual> generateFleEqualMatch(StringData path, + const PayloadT& ffp, + ExpressionContext* expCtx) { + // Generate { $_internalFleEq: { field: "$field_name", server: f_3, counter: cm, edc: k_EDC] } + auto tokens = ParsedFindPayload(ffp); + + uassert(6672401, + "Missing required field server encryption token in find payload", + tokens.serverToken.has_value()); + + return make_intrusive<ExpressionInternalFLEEqual>( + expCtx, + ExpressionFieldPath::createPathFromString( + expCtx, path.toString(), expCtx->variablesParseState), + tokens.serverToken.get().data, + tokens.maxCounter.value_or(0LL), + tokens.edcToken.data); +} + + +template <typename PayloadT> +std::unique_ptr<ExpressionInternalFLEEqual> generateFleEqualMatchUnique(StringData path, + const PayloadT& ffp, + ExpressionContext* expCtx) { + // Generate { $_internalFleEq: { field: "$field_name", server: f_3, counter: cm, edc: k_EDC] } + auto tokens = ParsedFindPayload(ffp); + + uassert(6672419, + "Missing required field server encryption token in find payload", + tokens.serverToken.has_value()); + + return std::make_unique<ExpressionInternalFLEEqual>( + expCtx, + ExpressionFieldPath::createPathFromString( + expCtx, path.toString(), expCtx->variablesParseState), + tokens.serverToken.get().data, + tokens.maxCounter.value_or(0LL), + tokens.edcToken.data); +} + +std::unique_ptr<MatchExpression> generateFleEqualMatchAndExpr(StringData path, + const BSONElement ffp, + ExpressionContext* expCtx) { + auto fleEqualMatch = generateFleEqualMatch(path, ffp, expCtx); + + return std::make_unique<ExprMatchExpression>(fleEqualMatch, expCtx); +} + + /** * This section defines a mapping from DocumentSources to the dispatch function to appropriately * handle FLE rewriting for that stage. This should be kept in line with code on the client-side @@ -128,7 +184,8 @@ public: * The final output will look like * {$or: [{$in: [tag0, "$__safeContent__"]}, {$in: [tag1, "$__safeContent__"]}, ...]}. */ - std::unique_ptr<Expression> rewriteComparisonsToEncryptedField( + std::unique_ptr<Expression> rewriteInToEncryptedField( + const Expression* leftExpr, const std::vector<boost::intrusive_ptr<Expression>>& equalitiesList) { size_t numFFPs = 0; std::vector<boost::intrusive_ptr<Expression>> orListElems; @@ -140,11 +197,122 @@ public: continue; } - // ... rewrite the payload to a list of tags... numFFPs++; + } + } + + // Finally, construct an $or of all of the $ins. + if (numFFPs == 0) { + return nullptr; + } + + uassert( + 6334102, + "If any elements in an comparison expression are encrypted, then all elements should " + "be encrypted.", + numFFPs == equalitiesList.size()); + + auto leftFieldPath = dynamic_cast<const ExpressionFieldPath*>(leftExpr); + uassert(6672417, + "$in is only supported with Queryable Encryption when the first argument is a " + "field path", + leftFieldPath != nullptr); + + if (!queryRewriter->isForceHighCardinality()) { + try { + for (auto& equality : equalitiesList) { + // For each expression representing a FleFindPayload... + if (auto constChild = dynamic_cast<ExpressionConstant*>(equality.get())) { + // ... rewrite the payload to a list of tags... + auto tags = queryRewriter->rewritePayloadAsTags(constChild->getValue()); + for (auto&& tagElt : tags) { + // ... and for each tag, construct expression {$in: [tag, + // "$__safeContent__"]}. + std::vector<boost::intrusive_ptr<Expression>> inVec{ + ExpressionConstant::create(queryRewriter->expCtx(), tagElt), + ExpressionFieldPath::createPathFromString( + queryRewriter->expCtx(), + kSafeContent, + queryRewriter->expCtx()->variablesParseState)}; + orListElems.push_back(make_intrusive<ExpressionIn>( + queryRewriter->expCtx(), std::move(inVec))); + } + } + } + + didRewrite = true; + + return std::make_unique<ExpressionOr>(queryRewriter->expCtx(), + std::move(orListElems)); + } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) { + LOGV2_DEBUG(6672403, + 2, + "FLE Max tag limit hit during aggregation $in rewrite", + "__error__"_attr = ex.what()); + + if (queryRewriter->getHighCardinalityMode() != + FLEQueryRewriter::HighCardinalityMode::kUseIfNeeded) { + throw; + } + + // fall through + } + } + + for (auto& equality : equalitiesList) { + if (auto constChild = dynamic_cast<ExpressionConstant*>(equality.get())) { + auto fleEqExpr = generateFleEqualMatch( + leftFieldPath->getFieldPathWithoutCurrentPrefix().fullPath(), + constChild->getValue(), + queryRewriter->expCtx()); + orListElems.push_back(fleEqExpr); + } + } + + didRewrite = true; + return std::make_unique<ExpressionOr>(queryRewriter->expCtx(), std::move(orListElems)); + } + + // Rewrite a [$eq : [$fieldpath, constant]] or [$eq: [constant, $fieldpath]] + // to _internalFleEq: {field: $fieldpath, edc: edcToken, counter: N, server: serverToken} + std::unique_ptr<Expression> rewriteComparisonsToEncryptedField( + const std::vector<boost::intrusive_ptr<Expression>>& equalitiesList) { + + auto leftConstant = dynamic_cast<ExpressionConstant*>(equalitiesList[0].get()); + auto rightConstant = dynamic_cast<ExpressionConstant*>(equalitiesList[1].get()); + + bool isLeftFFP = leftConstant && queryRewriter->isFleFindPayload(leftConstant->getValue()); + bool isRightFFP = + rightConstant && queryRewriter->isFleFindPayload(rightConstant->getValue()); + + uassert(6334100, + "Cannot compare two encrypted constants to each other", + !(isLeftFFP && isRightFFP)); + + // No FLE Find Payload + if (!isLeftFFP && !isRightFFP) { + return nullptr; + } + + auto leftFieldPath = dynamic_cast<ExpressionFieldPath*>(equalitiesList[0].get()); + auto rightFieldPath = dynamic_cast<ExpressionFieldPath*>(equalitiesList[1].get()); + + uassert( + 6672413, + "Queryable Encryption only supports comparisons between a field path and a constant", + leftFieldPath || rightFieldPath); + + auto fieldPath = leftFieldPath ? leftFieldPath : rightFieldPath; + auto constChild = isLeftFFP ? leftConstant : rightConstant; + + if (!queryRewriter->isForceHighCardinality()) { + try { + std::vector<boost::intrusive_ptr<Expression>> orListElems; + auto tags = queryRewriter->rewritePayloadAsTags(constChild->getValue()); for (auto&& tagElt : tags) { - // ... and for each tag, construct expression {$in: [tag, "$__safeContent__"]}. + // ... and for each tag, construct expression {$in: [tag, + // "$__safeContent__"]}. std::vector<boost::intrusive_ptr<Expression>> inVec{ ExpressionConstant::create(queryRewriter->expCtx(), tagElt), ExpressionFieldPath::createPathFromString( @@ -154,21 +322,33 @@ public: orListElems.push_back( make_intrusive<ExpressionIn>(queryRewriter->expCtx(), std::move(inVec))); } + + didRewrite = true; + return std::make_unique<ExpressionOr>(queryRewriter->expCtx(), + std::move(orListElems)); + + } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) { + LOGV2_DEBUG(6672409, + 2, + "FLE Max tag limit hit during query $in rewrite", + "__error__"_attr = ex.what()); + + if (queryRewriter->getHighCardinalityMode() != + FLEQueryRewriter::HighCardinalityMode::kUseIfNeeded) { + throw; + } + + // fall through } } - // Finally, construct an $or of all of the $ins. - if (numFFPs == 0) { - return nullptr; - } - uassert( - 6334102, - "If any elements in an comparison expression are encrypted, then all elements should " - "be encrypted.", - numFFPs == equalitiesList.size()); + auto fleEqExpr = + generateFleEqualMatchUnique(fieldPath->getFieldPathWithoutCurrentPrefix().fullPath(), + constChild->getValue(), + queryRewriter->expCtx()); didRewrite = true; - return std::make_unique<ExpressionOr>(queryRewriter->expCtx(), std::move(orListElems)); + return fleEqExpr; } std::unique_ptr<Expression> postVisit(Expression* exp) { @@ -177,30 +357,28 @@ public: // ignored when rewrites are done; there is no extra information in that child that // doesn't exist in the FFPs in the $in list. if (auto inList = dynamic_cast<ExpressionArray*>(inExpr->getOperandList()[1].get())) { - return rewriteComparisonsToEncryptedField(inList->getChildren()); + return rewriteInToEncryptedField(inExpr->getOperandList()[0].get(), + inList->getChildren()); } } else if (auto eqExpr = dynamic_cast<ExpressionCompare*>(exp); eqExpr && (eqExpr->getOp() == ExpressionCompare::EQ || eqExpr->getOp() == ExpressionCompare::NE)) { // Rewrite an $eq comparing an encrypted field and an encrypted constant to an $or. - // Either child may be the constant, so try rewriting both. - auto or0 = rewriteComparisonsToEncryptedField({eqExpr->getChildren()[0]}); - auto or1 = rewriteComparisonsToEncryptedField({eqExpr->getChildren()[1]}); - uassert(6334100, "Cannot compare two encrypted constants to each other", !or0 || !or1); + auto newExpr = rewriteComparisonsToEncryptedField(eqExpr->getChildren()); // Neither child is an encrypted constant, and no rewriting needs to be done. - if (!or0 && !or1) { + if (!newExpr) { return nullptr; } // Exactly one child was an encrypted constant. The other child can be ignored; there is // no extra information in that child that doesn't exist in the FFP. if (eqExpr->getOp() == ExpressionCompare::NE) { - std::vector<boost::intrusive_ptr<Expression>> notChild{(or0 ? or0 : or1).release()}; + std::vector<boost::intrusive_ptr<Expression>> notChild{newExpr.release()}; return std::make_unique<ExpressionNot>(queryRewriter->expCtx(), std::move(notChild)); } - return std::move(or0 ? or0 : or1); + return newExpr; } return nullptr; @@ -213,11 +391,14 @@ public: BSONObj rewriteEncryptedFilter(const FLEStateCollectionReader& escReader, const FLEStateCollectionReader& eccReader, boost::intrusive_ptr<ExpressionContext> expCtx, - BSONObj filter) { + BSONObj filter, + HighCardinalityModeAllowed mode) { + if (auto rewritten = - FLEQueryRewriter(expCtx, escReader, eccReader).rewriteMatchExpression(filter)) { + FLEQueryRewriter(expCtx, escReader, eccReader, mode).rewriteMatchExpression(filter)) { return rewritten.get(); } + return filter; } @@ -273,16 +454,18 @@ public: FilterRewrite(boost::intrusive_ptr<ExpressionContext> expCtx, const NamespaceString& nss, const EncryptionInformation& encryptInfo, - const BSONObj toRewrite) - : RewriteBase(expCtx, nss, encryptInfo), userFilter(toRewrite) {} + const BSONObj toRewrite, + HighCardinalityModeAllowed mode) + : RewriteBase(expCtx, nss, encryptInfo), userFilter(toRewrite), _mode(mode) {} ~FilterRewrite(){}; void doRewrite(FLEStateCollectionReader& escReader, FLEStateCollectionReader& eccReader) final { - rewrittenFilter = rewriteEncryptedFilter(escReader, eccReader, expCtx, userFilter); + rewrittenFilter = rewriteEncryptedFilter(escReader, eccReader, expCtx, userFilter, _mode); } const BSONObj userFilter; BSONObj rewrittenFilter; + HighCardinalityModeAllowed _mode; }; // This helper executes the rewrite(s) inside a transaction. The transaction runs in a separate @@ -324,7 +507,8 @@ BSONObj rewriteEncryptedFilterInsideTxn(FLEQueryInterface* queryImpl, StringData db, const EncryptedFieldConfig& efc, boost::intrusive_ptr<ExpressionContext> expCtx, - BSONObj filter) { + BSONObj filter, + HighCardinalityModeAllowed mode) { auto makeCollectionReader = [&](FLEQueryInterface* queryImpl, const StringData& coll) { NamespaceString nss(db, coll); auto docCount = queryImpl->countDocuments(nss); @@ -332,7 +516,8 @@ BSONObj rewriteEncryptedFilterInsideTxn(FLEQueryInterface* queryImpl, }; auto escReader = makeCollectionReader(queryImpl, efc.getEscCollection().get()); auto eccReader = makeCollectionReader(queryImpl, efc.getEccCollection().get()); - return rewriteEncryptedFilter(escReader, eccReader, expCtx, filter); + + return rewriteEncryptedFilter(escReader, eccReader, expCtx, filter, mode); } BSONObj rewriteQuery(OperationContext* opCtx, @@ -340,8 +525,9 @@ BSONObj rewriteQuery(OperationContext* opCtx, const NamespaceString& nss, const EncryptionInformation& info, BSONObj filter, - GetTxnCallback getTransaction) { - auto sharedBlock = std::make_shared<FilterRewrite>(expCtx, nss, info, filter); + GetTxnCallback getTransaction, + HighCardinalityModeAllowed mode) { + auto sharedBlock = std::make_shared<FilterRewrite>(expCtx, nss, info, filter, mode); doFLERewriteInTxn(opCtx, sharedBlock, getTransaction); return sharedBlock->rewrittenFilter.getOwned(); } @@ -365,7 +551,8 @@ void processFindCommand(OperationContext* opCtx, nss, findCommand->getEncryptionInformation().get(), findCommand->getFilter().getOwned(), - getTransaction)); + getTransaction, + HighCardinalityModeAllowed::kAllow)); // The presence of encryptionInformation is a signal that this is a FLE request that requires // special processing. Once we've rewritten the query, it's no longer a "special" FLE query, but // a normal query that can be executed by the query system like any other, so remove @@ -389,7 +576,8 @@ void processCountCommand(OperationContext* opCtx, nss, countCommand->getEncryptionInformation().get(), countCommand->getQuery().getOwned(), - getTxn)); + getTxn, + HighCardinalityModeAllowed::kAllow)); // The presence of encryptionInformation is a signal that this is a FLE request that requires // special processing. Once we've rewritten the query, it's no longer a "special" FLE query, but // a normal query that can be executed by the query system like any other, so remove @@ -504,59 +692,112 @@ std::vector<Value> FLEQueryRewriter::rewritePayloadAsTags(Value fleFindPayload) return tagVec; } -std::unique_ptr<InMatchExpression> FLEQueryRewriter::rewriteEq( - const EqualityMatchExpression* expr) { + +std::unique_ptr<MatchExpression> FLEQueryRewriter::rewriteEq(const EqualityMatchExpression* expr) { auto ffp = expr->getData(); if (!isFleFindPayload(ffp)) { return nullptr; } - auto obj = rewritePayloadAsTags(ffp); - - auto tags = std::vector<BSONElement>(); - obj.elems(tags); + if (_mode != HighCardinalityMode::kForceAlways) { + try { + auto obj = rewritePayloadAsTags(ffp); + + auto tags = std::vector<BSONElement>(); + obj.elems(tags); + + auto inExpr = std::make_unique<InMatchExpression>(kSafeContent); + inExpr->setBackingBSON(std::move(obj)); + auto status = inExpr->setEqualities(std::move(tags)); + uassertStatusOK(status); + _rewroteLastExpression = true; + return inExpr; + } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) { + LOGV2_DEBUG(6672410, + 2, + "FLE Max tag limit hit during query $eq rewrite", + "__error__"_attr = ex.what()); + + if (_mode != HighCardinalityMode::kUseIfNeeded) { + throw; + } - auto inExpr = std::make_unique<InMatchExpression>(kSafeContent); - inExpr->setBackingBSON(std::move(obj)); - auto status = inExpr->setEqualities(std::move(tags)); - uassertStatusOK(status); + // fall through + } + } + auto exprMatch = generateFleEqualMatchAndExpr(expr->path(), ffp, _expCtx.get()); _rewroteLastExpression = true; - return inExpr; + return exprMatch; } -std::unique_ptr<InMatchExpression> FLEQueryRewriter::rewriteIn(const InMatchExpression* expr) { - auto backingBSONBuilder = BSONArrayBuilder(); +std::unique_ptr<MatchExpression> FLEQueryRewriter::rewriteIn(const InMatchExpression* expr) { size_t numFFPs = 0; for (auto& eq : expr->getEqualities()) { if (isFleFindPayload(eq)) { - auto obj = rewritePayloadAsTags(eq); ++numFFPs; - for (auto&& elt : obj) { - backingBSONBuilder.append(elt); - } } } + if (numFFPs == 0) { return nullptr; } + // All elements in an encrypted $in expression should be FFPs. uassert( 6329400, "If any elements in a $in expression are encrypted, then all elements should be encrypted.", numFFPs == expr->getEqualities().size()); - auto backingBSON = backingBSONBuilder.arr(); - auto allTags = std::vector<BSONElement>(); - backingBSON.elems(allTags); + if (_mode != HighCardinalityMode::kForceAlways) { + + try { + auto backingBSONBuilder = BSONArrayBuilder(); + + for (auto& eq : expr->getEqualities()) { + auto obj = rewritePayloadAsTags(eq); + for (auto&& elt : obj) { + backingBSONBuilder.append(elt); + } + } - auto inExpr = std::make_unique<InMatchExpression>(kSafeContent); - inExpr->setBackingBSON(std::move(backingBSON)); - auto status = inExpr->setEqualities(std::move(allTags)); - uassertStatusOK(status); + auto backingBSON = backingBSONBuilder.arr(); + auto allTags = std::vector<BSONElement>(); + backingBSON.elems(allTags); + + auto inExpr = std::make_unique<InMatchExpression>(kSafeContent); + inExpr->setBackingBSON(std::move(backingBSON)); + auto status = inExpr->setEqualities(std::move(allTags)); + uassertStatusOK(status); + + _rewroteLastExpression = true; + return inExpr; + + } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) { + LOGV2_DEBUG(6672411, + 2, + "FLE Max tag limit hit during query $in rewrite", + "__error__"_attr = ex.what()); + + if (_mode != HighCardinalityMode::kUseIfNeeded) { + throw; + } + + // fall through + } + } + + std::vector<std::unique_ptr<MatchExpression>> matches; + matches.reserve(numFFPs); + + for (auto& eq : expr->getEqualities()) { + auto exprMatch = generateFleEqualMatchAndExpr(expr->path(), eq, _expCtx.get()); + matches.push_back(std::move(exprMatch)); + } + auto orExpr = std::make_unique<OrMatchExpression>(std::move(matches)); _rewroteLastExpression = true; - return inExpr; + return orExpr; } } // namespace mongo::fle |