summaryrefslogtreecommitdiff
path: root/src/mongo/db/query/fle/server_rewrite.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/db/query/fle/server_rewrite.cpp')
-rw-r--r--src/mongo/db/query/fle/server_rewrite.cpp353
1 files changed, 296 insertions, 57 deletions
diff --git a/src/mongo/db/query/fle/server_rewrite.cpp b/src/mongo/db/query/fle/server_rewrite.cpp
index 185a979ec3d..1660c368093 100644
--- a/src/mongo/db/query/fle/server_rewrite.cpp
+++ b/src/mongo/db/query/fle/server_rewrite.cpp
@@ -27,11 +27,13 @@
* it in the license file.
*/
+#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kQuery
#include "mongo/db/query/fle/server_rewrite.h"
#include <memory>
+#include "mongo/bson/bsonmisc.h"
#include "mongo/bson/bsonobj.h"
#include "mongo/bson/bsonobjbuilder.h"
#include "mongo/bson/bsontypes.h"
@@ -48,9 +50,11 @@
#include "mongo/db/pipeline/expression.h"
#include "mongo/db/query/collation/collator_factory_interface.h"
#include "mongo/db/service_context.h"
+#include "mongo/logv2/log.h"
#include "mongo/s/grid.h"
#include "mongo/s/transaction_router_resource_yielder.h"
#include "mongo/util/assert_util.h"
+#include "mongo/util/intrusive_counter.h"
namespace mongo::fle {
@@ -68,6 +72,56 @@ std::unique_ptr<CollatorInterface> collatorFromBSON(OperationContext* opCtx,
}
namespace {
+template <typename PayloadT>
+boost::intrusive_ptr<ExpressionInternalFLEEqual> generateFleEqualMatch(StringData path,
+ const PayloadT& ffp,
+ ExpressionContext* expCtx) {
+ // Generate { $_internalFleEq: { field: "$field_name", server: f_3, counter: cm, edc: k_EDC } }
+ auto tokens = ParsedFindPayload(ffp);
+
+ uassert(6672401,
+ "Missing required field server encryption token in find payload",
+ tokens.serverToken.has_value());
+
+ return make_intrusive<ExpressionInternalFLEEqual>(
+ expCtx,
+ ExpressionFieldPath::createPathFromString(
+ expCtx, path.toString(), expCtx->variablesParseState),
+ tokens.serverToken.get().data,
+ tokens.maxCounter.value_or(0LL),
+ tokens.edcToken.data);
+}
+
+
+template <typename PayloadT>
+std::unique_ptr<ExpressionInternalFLEEqual> generateFleEqualMatchUnique(StringData path,
+ const PayloadT& ffp,
+ ExpressionContext* expCtx) {
+ // Generate { $_internalFleEq: { field: "$field_name", server: f_3, counter: cm, edc: k_EDC } }
+ auto tokens = ParsedFindPayload(ffp);
+
+ uassert(6672419,
+ "Missing required field server encryption token in find payload",
+ tokens.serverToken.has_value());
+
+ return std::make_unique<ExpressionInternalFLEEqual>(
+ expCtx,
+ ExpressionFieldPath::createPathFromString(
+ expCtx, path.toString(), expCtx->variablesParseState),
+ tokens.serverToken.get().data,
+ tokens.maxCounter.value_or(0LL),
+ tokens.edcToken.data);
+}
+
+std::unique_ptr<MatchExpression> generateFleEqualMatchAndExpr(StringData path,
+ const BSONElement ffp,
+ ExpressionContext* expCtx) {
+ auto fleEqualMatch = generateFleEqualMatch(path, ffp, expCtx);
+
+ return std::make_unique<ExprMatchExpression>(fleEqualMatch, expCtx);
+}
+
+
/**
* This section defines a mapping from DocumentSources to the dispatch function to appropriately
* handle FLE rewriting for that stage. This should be kept in line with code on the client-side
@@ -128,7 +182,8 @@ public:
* The final output will look like
* {$or: [{$in: [tag0, "$__safeContent__"]}, {$in: [tag1, "$__safeContent__"]}, ...]}.
*/
- std::unique_ptr<Expression> rewriteComparisonsToEncryptedField(
+ std::unique_ptr<Expression> rewriteInToEncryptedField(
+ const Expression* leftExpr,
const std::vector<boost::intrusive_ptr<Expression>>& equalitiesList) {
size_t numFFPs = 0;
std::vector<boost::intrusive_ptr<Expression>> orListElems;
@@ -140,11 +195,122 @@ public:
continue;
}
- // ... rewrite the payload to a list of tags...
numFFPs++;
+ }
+ }
+
+ // Finally, construct an $or of all of the $ins.
+ if (numFFPs == 0) {
+ return nullptr;
+ }
+
+ uassert(
+ 6334102,
+ "If any elements in an comparison expression are encrypted, then all elements should "
+ "be encrypted.",
+ numFFPs == equalitiesList.size());
+
+ auto leftFieldPath = dynamic_cast<const ExpressionFieldPath*>(leftExpr);
+ uassert(6672417,
+ "$in is only supported with Queryable Encryption when the first argument is a "
+ "field path",
+ leftFieldPath != nullptr);
+
+ if (!queryRewriter->isForceHighCardinality()) {
+ try {
+ for (auto& equality : equalitiesList) {
+ // For each expression representing a FleFindPayload...
+ if (auto constChild = dynamic_cast<ExpressionConstant*>(equality.get())) {
+ // ... rewrite the payload to a list of tags...
+ auto tags = queryRewriter->rewritePayloadAsTags(constChild->getValue());
+ for (auto&& tagElt : tags) {
+ // ... and for each tag, construct expression {$in: [tag,
+ // "$__safeContent__"]}.
+ std::vector<boost::intrusive_ptr<Expression>> inVec{
+ ExpressionConstant::create(queryRewriter->expCtx(), tagElt),
+ ExpressionFieldPath::createPathFromString(
+ queryRewriter->expCtx(),
+ kSafeContent,
+ queryRewriter->expCtx()->variablesParseState)};
+ orListElems.push_back(make_intrusive<ExpressionIn>(
+ queryRewriter->expCtx(), std::move(inVec)));
+ }
+ }
+ }
+
+ didRewrite = true;
+
+ return std::make_unique<ExpressionOr>(queryRewriter->expCtx(),
+ std::move(orListElems));
+ } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) {
+ LOGV2_DEBUG(6672403,
+ 2,
+ "FLE Max tag limit hit during aggregation $in rewrite",
+ "__error__"_attr = ex.what());
+
+ if (queryRewriter->getHighCardinalityMode() !=
+ FLEQueryRewriter::HighCardinalityMode::kUseIfNeeded) {
+ throw;
+ }
+
+ // fall through
+ }
+ }
+
+ for (auto& equality : equalitiesList) {
+ if (auto constChild = dynamic_cast<ExpressionConstant*>(equality.get())) {
+ auto fleEqExpr = generateFleEqualMatch(
+ leftFieldPath->getFieldPathWithoutCurrentPrefix().fullPath(),
+ constChild->getValue(),
+ queryRewriter->expCtx());
+ orListElems.push_back(fleEqExpr);
+ }
+ }
+
+ didRewrite = true;
+ return std::make_unique<ExpressionOr>(queryRewriter->expCtx(), std::move(orListElems));
+ }
+
+ // Rewrite a [$eq : [$fieldpath, constant]] or [$eq: [constant, $fieldpath]]
+ // to $_internalFleEq: {field: $fieldpath, edc: edcToken, counter: N, server: serverToken}
+ std::unique_ptr<Expression> rewriteComparisonsToEncryptedField(
+ const std::vector<boost::intrusive_ptr<Expression>>& equalitiesList) {
+
+ auto leftConstant = dynamic_cast<ExpressionConstant*>(equalitiesList[0].get());
+ auto rightConstant = dynamic_cast<ExpressionConstant*>(equalitiesList[1].get());
+
+ bool isLeftFFP = leftConstant && queryRewriter->isFleFindPayload(leftConstant->getValue());
+ bool isRightFFP =
+ rightConstant && queryRewriter->isFleFindPayload(rightConstant->getValue());
+
+ uassert(6334100,
+ "Cannot compare two encrypted constants to each other",
+ !(isLeftFFP && isRightFFP));
+
+ // No FLE Find Payload
+ if (!isLeftFFP && !isRightFFP) {
+ return nullptr;
+ }
+
+ auto leftFieldPath = dynamic_cast<ExpressionFieldPath*>(equalitiesList[0].get());
+ auto rightFieldPath = dynamic_cast<ExpressionFieldPath*>(equalitiesList[1].get());
+
+ uassert(
+ 6672413,
+ "Queryable Encryption only supports comparisons between a field path and a constant",
+ leftFieldPath || rightFieldPath);
+
+ auto fieldPath = leftFieldPath ? leftFieldPath : rightFieldPath;
+ auto constChild = isLeftFFP ? leftConstant : rightConstant;
+
+ if (!queryRewriter->isForceHighCardinality()) {
+ try {
+ std::vector<boost::intrusive_ptr<Expression>> orListElems;
+
auto tags = queryRewriter->rewritePayloadAsTags(constChild->getValue());
for (auto&& tagElt : tags) {
- // ... and for each tag, construct expression {$in: [tag, "$__safeContent__"]}.
+ // ... and for each tag, construct expression {$in: [tag,
+ // "$__safeContent__"]}.
std::vector<boost::intrusive_ptr<Expression>> inVec{
ExpressionConstant::create(queryRewriter->expCtx(), tagElt),
ExpressionFieldPath::createPathFromString(
@@ -154,21 +320,33 @@ public:
orListElems.push_back(
make_intrusive<ExpressionIn>(queryRewriter->expCtx(), std::move(inVec)));
}
+
+ didRewrite = true;
+ return std::make_unique<ExpressionOr>(queryRewriter->expCtx(),
+ std::move(orListElems));
+
+ } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) {
+ LOGV2_DEBUG(6672409,
+ 2,
+ "FLE Max tag limit hit during query $in rewrite",
+ "__error__"_attr = ex.what());
+
+ if (queryRewriter->getHighCardinalityMode() !=
+ FLEQueryRewriter::HighCardinalityMode::kUseIfNeeded) {
+ throw;
+ }
+
+ // fall through
}
}
- // Finally, construct an $or of all of the $ins.
- if (numFFPs == 0) {
- return nullptr;
- }
- uassert(
- 6334102,
- "If any elements in an comparison expression are encrypted, then all elements should "
- "be encrypted.",
- numFFPs == equalitiesList.size());
+ auto fleEqExpr =
+ generateFleEqualMatchUnique(fieldPath->getFieldPathWithoutCurrentPrefix().fullPath(),
+ constChild->getValue(),
+ queryRewriter->expCtx());
didRewrite = true;
- return std::make_unique<ExpressionOr>(queryRewriter->expCtx(), std::move(orListElems));
+ return fleEqExpr;
}
std::unique_ptr<Expression> postVisit(Expression* exp) {
@@ -177,30 +355,28 @@ public:
// ignored when rewrites are done; there is no extra information in that child that
// doesn't exist in the FFPs in the $in list.
if (auto inList = dynamic_cast<ExpressionArray*>(inExpr->getOperandList()[1].get())) {
- return rewriteComparisonsToEncryptedField(inList->getChildren());
+ return rewriteInToEncryptedField(inExpr->getOperandList()[0].get(),
+ inList->getChildren());
}
} else if (auto eqExpr = dynamic_cast<ExpressionCompare*>(exp); eqExpr &&
(eqExpr->getOp() == ExpressionCompare::EQ ||
eqExpr->getOp() == ExpressionCompare::NE)) {
// Rewrite an $eq comparing an encrypted field and an encrypted constant to an $or of tag matches or to $_internalFleEq.
- // Either child may be the constant, so try rewriting both.
- auto or0 = rewriteComparisonsToEncryptedField({eqExpr->getChildren()[0]});
- auto or1 = rewriteComparisonsToEncryptedField({eqExpr->getChildren()[1]});
- uassert(6334100, "Cannot compare two encrypted constants to each other", !or0 || !or1);
+ auto newExpr = rewriteComparisonsToEncryptedField(eqExpr->getChildren());
// Neither child is an encrypted constant, and no rewriting needs to be done.
- if (!or0 && !or1) {
+ if (!newExpr) {
return nullptr;
}
// Exactly one child was an encrypted constant. The other child can be ignored; there is
// no extra information in that child that doesn't exist in the FFP.
if (eqExpr->getOp() == ExpressionCompare::NE) {
- std::vector<boost::intrusive_ptr<Expression>> notChild{(or0 ? or0 : or1).release()};
+ std::vector<boost::intrusive_ptr<Expression>> notChild{newExpr.release()};
return std::make_unique<ExpressionNot>(queryRewriter->expCtx(),
std::move(notChild));
}
- return std::move(or0 ? or0 : or1);
+ return newExpr;
}
return nullptr;
@@ -213,11 +389,14 @@ public:
BSONObj rewriteEncryptedFilter(const FLEStateCollectionReader& escReader,
const FLEStateCollectionReader& eccReader,
boost::intrusive_ptr<ExpressionContext> expCtx,
- BSONObj filter) {
+ BSONObj filter,
+ HighCardinalityModeAllowed mode) {
+
if (auto rewritten =
- FLEQueryRewriter(expCtx, escReader, eccReader).rewriteMatchExpression(filter)) {
+ FLEQueryRewriter(expCtx, escReader, eccReader, mode).rewriteMatchExpression(filter)) {
return rewritten.get();
}
+
return filter;
}
@@ -273,16 +452,18 @@ public:
FilterRewrite(boost::intrusive_ptr<ExpressionContext> expCtx,
const NamespaceString& nss,
const EncryptionInformation& encryptInfo,
- const BSONObj toRewrite)
- : RewriteBase(expCtx, nss, encryptInfo), userFilter(toRewrite) {}
+ const BSONObj toRewrite,
+ HighCardinalityModeAllowed mode)
+ : RewriteBase(expCtx, nss, encryptInfo), userFilter(toRewrite), _mode(mode) {}
~FilterRewrite(){};
void doRewrite(FLEStateCollectionReader& escReader, FLEStateCollectionReader& eccReader) final {
- rewrittenFilter = rewriteEncryptedFilter(escReader, eccReader, expCtx, userFilter);
+ rewrittenFilter = rewriteEncryptedFilter(escReader, eccReader, expCtx, userFilter, _mode);
}
const BSONObj userFilter;
BSONObj rewrittenFilter;
+ HighCardinalityModeAllowed _mode;
};
// This helper executes the rewrite(s) inside a transaction. The transaction runs in a separate
@@ -324,7 +505,8 @@ BSONObj rewriteEncryptedFilterInsideTxn(FLEQueryInterface* queryImpl,
StringData db,
const EncryptedFieldConfig& efc,
boost::intrusive_ptr<ExpressionContext> expCtx,
- BSONObj filter) {
+ BSONObj filter,
+ HighCardinalityModeAllowed mode) {
auto makeCollectionReader = [&](FLEQueryInterface* queryImpl, const StringData& coll) {
NamespaceString nss(db, coll);
auto docCount = queryImpl->countDocuments(nss);
@@ -332,7 +514,8 @@ BSONObj rewriteEncryptedFilterInsideTxn(FLEQueryInterface* queryImpl,
};
auto escReader = makeCollectionReader(queryImpl, efc.getEscCollection().get());
auto eccReader = makeCollectionReader(queryImpl, efc.getEccCollection().get());
- return rewriteEncryptedFilter(escReader, eccReader, expCtx, filter);
+
+ return rewriteEncryptedFilter(escReader, eccReader, expCtx, filter, mode);
}
BSONObj rewriteQuery(OperationContext* opCtx,
@@ -340,8 +523,9 @@ BSONObj rewriteQuery(OperationContext* opCtx,
const NamespaceString& nss,
const EncryptionInformation& info,
BSONObj filter,
- GetTxnCallback getTransaction) {
- auto sharedBlock = std::make_shared<FilterRewrite>(expCtx, nss, info, filter);
+ GetTxnCallback getTransaction,
+ HighCardinalityModeAllowed mode) {
+ auto sharedBlock = std::make_shared<FilterRewrite>(expCtx, nss, info, filter, mode);
doFLERewriteInTxn(opCtx, sharedBlock, getTransaction);
return sharedBlock->rewrittenFilter.getOwned();
}
@@ -365,7 +549,8 @@ void processFindCommand(OperationContext* opCtx,
nss,
findCommand->getEncryptionInformation().get(),
findCommand->getFilter().getOwned(),
- getTransaction));
+ getTransaction,
+ HighCardinalityModeAllowed::kAllow));
// The presence of encryptionInformation is a signal that this is a FLE request that requires
// special processing. Once we've rewritten the query, it's no longer a "special" FLE query, but
// a normal query that can be executed by the query system like any other, so remove
@@ -389,7 +574,8 @@ void processCountCommand(OperationContext* opCtx,
nss,
countCommand->getEncryptionInformation().get(),
countCommand->getQuery().getOwned(),
- getTxn));
+ getTxn,
+ HighCardinalityModeAllowed::kAllow));
// The presence of encryptionInformation is a signal that this is a FLE request that requires
// special processing. Once we've rewritten the query, it's no longer a "special" FLE query, but
// a normal query that can be executed by the query system like any other, so remove
@@ -503,59 +689,112 @@ std::vector<Value> FLEQueryRewriter::rewritePayloadAsTags(Value fleFindPayload)
return tagVec;
}
-std::unique_ptr<InMatchExpression> FLEQueryRewriter::rewriteEq(
- const EqualityMatchExpression* expr) {
+
+std::unique_ptr<MatchExpression> FLEQueryRewriter::rewriteEq(const EqualityMatchExpression* expr) {
auto ffp = expr->getData();
if (!isFleFindPayload(ffp)) {
return nullptr;
}
- auto obj = rewritePayloadAsTags(ffp);
-
- auto tags = std::vector<BSONElement>();
- obj.elems(tags);
+ if (_mode != HighCardinalityMode::kForceAlways) {
+ try {
+ auto obj = rewritePayloadAsTags(ffp);
+
+ auto tags = std::vector<BSONElement>();
+ obj.elems(tags);
+
+ auto inExpr = std::make_unique<InMatchExpression>(kSafeContent);
+ inExpr->setBackingBSON(std::move(obj));
+ auto status = inExpr->setEqualities(std::move(tags));
+ uassertStatusOK(status);
+ _rewroteLastExpression = true;
+ return inExpr;
+ } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) {
+ LOGV2_DEBUG(6672410,
+ 2,
+ "FLE Max tag limit hit during query $eq rewrite",
+ "__error__"_attr = ex.what());
+
+ if (_mode != HighCardinalityMode::kUseIfNeeded) {
+ throw;
+ }
- auto inExpr = std::make_unique<InMatchExpression>(kSafeContent);
- inExpr->setBackingBSON(std::move(obj));
- auto status = inExpr->setEqualities(std::move(tags));
- uassertStatusOK(status);
+ // fall through
+ }
+ }
+ auto exprMatch = generateFleEqualMatchAndExpr(expr->path(), ffp, _expCtx.get());
_rewroteLastExpression = true;
- return inExpr;
+ return exprMatch;
}
-std::unique_ptr<InMatchExpression> FLEQueryRewriter::rewriteIn(const InMatchExpression* expr) {
- auto backingBSONBuilder = BSONArrayBuilder();
+std::unique_ptr<MatchExpression> FLEQueryRewriter::rewriteIn(const InMatchExpression* expr) {
size_t numFFPs = 0;
for (auto& eq : expr->getEqualities()) {
if (isFleFindPayload(eq)) {
- auto obj = rewritePayloadAsTags(eq);
++numFFPs;
- for (auto&& elt : obj) {
- backingBSONBuilder.append(elt);
- }
}
}
+
if (numFFPs == 0) {
return nullptr;
}
+
// All elements in an encrypted $in expression should be FFPs.
uassert(
6329400,
"If any elements in a $in expression are encrypted, then all elements should be encrypted.",
numFFPs == expr->getEqualities().size());
- auto backingBSON = backingBSONBuilder.arr();
- auto allTags = std::vector<BSONElement>();
- backingBSON.elems(allTags);
+ if (_mode != HighCardinalityMode::kForceAlways) {
+
+ try {
+ auto backingBSONBuilder = BSONArrayBuilder();
+
+ for (auto& eq : expr->getEqualities()) {
+ auto obj = rewritePayloadAsTags(eq);
+ for (auto&& elt : obj) {
+ backingBSONBuilder.append(elt);
+ }
+ }
+
+ auto backingBSON = backingBSONBuilder.arr();
+ auto allTags = std::vector<BSONElement>();
+ backingBSON.elems(allTags);
+
+ auto inExpr = std::make_unique<InMatchExpression>(kSafeContent);
+ inExpr->setBackingBSON(std::move(backingBSON));
+ auto status = inExpr->setEqualities(std::move(allTags));
+ uassertStatusOK(status);
+
+ _rewroteLastExpression = true;
+ return inExpr;
+
+ } catch (const ExceptionFor<ErrorCodes::FLEMaxTagLimitExceeded>& ex) {
+ LOGV2_DEBUG(6672411,
+ 2,
+ "FLE Max tag limit hit during query $in rewrite",
+ "__error__"_attr = ex.what());
+
+ if (_mode != HighCardinalityMode::kUseIfNeeded) {
+ throw;
+ }
+
+ // fall through
+ }
+ }
+
+ std::vector<std::unique_ptr<MatchExpression>> matches;
+ matches.reserve(numFFPs);
- auto inExpr = std::make_unique<InMatchExpression>(kSafeContent);
- inExpr->setBackingBSON(std::move(backingBSON));
- auto status = inExpr->setEqualities(std::move(allTags));
- uassertStatusOK(status);
+ for (auto& eq : expr->getEqualities()) {
+ auto exprMatch = generateFleEqualMatchAndExpr(expr->path(), eq, _expCtx.get());
+ matches.push_back(std::move(exprMatch));
+ }
+ auto orExpr = std::make_unique<OrMatchExpression>(std::move(matches));
_rewroteLastExpression = true;
- return inExpr;
+ return orExpr;
}
} // namespace mongo::fle