summaryrefslogtreecommitdiff
path: root/src/mongo/db
diff options
context:
space:
mode:
authorDavis Haupt <davis.haupt@mongodb.com>2022-10-03 17:27:11 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2022-10-03 19:15:24 +0000
commit46ad1bc5186a6cc4d449d077b12951a400595f4e (patch)
treefaa45e1bb8c5ae93363b4334941b62c987ec4e99 /src/mongo/db
parent801e5203e0efcb62ffcef66e22da95d645b2dca2 (diff)
downloadmongo-46ad1bc5186a6cc4d449d077b12951a400595f4e.tar.gz
SERVER-67209 Server-side rewrite for aggregation to tag disjunction
Diffstat (limited to 'src/mongo/db')
-rw-r--r--src/mongo/db/query/fle/encrypted_predicate.cpp15
-rw-r--r--src/mongo/db/query/fle/encrypted_predicate.h3
-rw-r--r--src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h1
-rw-r--r--src/mongo/db/query/fle/equality_predicate.cpp34
-rw-r--r--src/mongo/db/query/fle/range_predicate.cpp27
-rw-r--r--src/mongo/db/query/fle/range_predicate_test.cpp92
6 files changed, 117 insertions, 55 deletions
diff --git a/src/mongo/db/query/fle/encrypted_predicate.cpp b/src/mongo/db/query/fle/encrypted_predicate.cpp
index aeed43cc884..74d2ef8fbf4 100644
--- a/src/mongo/db/query/fle/encrypted_predicate.cpp
+++ b/src/mongo/db/query/fle/encrypted_predicate.cpp
@@ -72,5 +72,20 @@ std::vector<Value> toValues(std::vector<PrfBlock>&& vec) {
}
return output;
}
+
+std::unique_ptr<Expression> makeTagDisjunction(ExpressionContext* expCtx,
+ std::vector<Value>&& tags) {
+ std::vector<boost::intrusive_ptr<Expression>> orListElems;
+ for (auto&& tagElt : tags) {
+ // ... and for each tag, construct expression {$in: [tag,
+ // "$__safeContent__"]}.
+ std::vector<boost::intrusive_ptr<Expression>> inVec{
+ ExpressionConstant::create(expCtx, tagElt),
+ ExpressionFieldPath::createPathFromString(
+ expCtx, kSafeContent, expCtx->variablesParseState)};
+ orListElems.push_back(make_intrusive<ExpressionIn>(expCtx, std::move(inVec)));
+ }
+ return std::make_unique<ExpressionOr>(expCtx, std::move(orListElems));
+}
} // namespace fle
} // namespace mongo
diff --git a/src/mongo/db/query/fle/encrypted_predicate.h b/src/mongo/db/query/fle/encrypted_predicate.h
index d6c379f6350..5fcac966dbf 100644
--- a/src/mongo/db/query/fle/encrypted_predicate.h
+++ b/src/mongo/db/query/fle/encrypted_predicate.h
@@ -72,6 +72,9 @@ T parseFindPayload(BSONValue payload) {
payload);
}
+std::unique_ptr<Expression> makeTagDisjunction(ExpressionContext* expCtx,
+ std::vector<Value>&& tags);
+
/**
* Convert a vector of PrfBlocks to a BSONArray for use in MatchExpression tag generation.
*/
diff --git a/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h b/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h
index 88a6c85983a..5c06e89a0da 100644
--- a/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h
+++ b/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h
@@ -102,5 +102,6 @@ public:
protected:
MockServerRewrite _mock{};
+ ExpressionContextForTest _expCtx;
};
} // namespace mongo::fle
diff --git a/src/mongo/db/query/fle/equality_predicate.cpp b/src/mongo/db/query/fle/equality_predicate.cpp
index a349962145b..88396949ee8 100644
--- a/src/mongo/db/query/fle/equality_predicate.cpp
+++ b/src/mongo/db/query/fle/equality_predicate.cpp
@@ -286,20 +286,8 @@ std::unique_ptr<Expression> EqualityPredicate::rewriteToTagDisjunction(Expressio
std::vector<boost::intrusive_ptr<Expression>> orListElems;
auto payload = constChild->getValue();
auto tags = toValues(generateTags(std::ref(payload)));
- for (auto&& tagElt : tags) {
- // ... and for each tag, construct expression {$in: [tag,
- // "$__safeContent__"]}.
- std::vector<boost::intrusive_ptr<Expression>> inVec{
- ExpressionConstant::create(_rewriter->getExpressionContext(), tagElt),
- ExpressionFieldPath::createPathFromString(
- _rewriter->getExpressionContext(),
- kSafeContent,
- _rewriter->getExpressionContext()->variablesParseState)};
- orListElems.push_back(
- make_intrusive<ExpressionIn>(_rewriter->getExpressionContext(), std::move(inVec)));
- }
- auto disjunction = std::make_unique<ExpressionOr>(_rewriter->getExpressionContext(),
- std::move(orListElems));
+ auto disjunction = makeTagDisjunction(_rewriter->getExpressionContext(), std::move(tags));
+
if (eqExpr->getOp() == ExpressionCompare::NE) {
std::vector<boost::intrusive_ptr<Expression>> notChild{disjunction.release()};
return std::make_unique<ExpressionNot>(_rewriter->getExpressionContext(),
@@ -312,27 +300,19 @@ std::unique_ptr<Expression> EqualityPredicate::rewriteToTagDisjunction(Expressio
return nullptr;
}
auto& equalitiesList = inList->getChildren();
- std::vector<boost::intrusive_ptr<Expression>> orListElems;
- auto expCtx = _rewriter->getExpressionContext();
+ std::vector<Value> allTags;
for (auto& equality : equalitiesList) {
// For each expression representing a FleFindPayload...
if (auto constChild = dynamic_cast<ExpressionConstant*>(equality.get())) {
// ... rewrite the payload to a list of tags...
auto payload = constChild->getValue();
auto tags = toValues(generateTags(std::ref(payload)));
- for (auto&& tagElt : tags) {
- // ... and for each tag, construct expression {$in: [tag,
- // "$__safeContent__"]}.
- std::vector<boost::intrusive_ptr<Expression>> inVec{
- ExpressionConstant::create(expCtx, tagElt),
- ExpressionFieldPath::createPathFromString(
- expCtx, kSafeContent, expCtx->variablesParseState)};
- orListElems.push_back(
- make_intrusive<ExpressionIn>(expCtx, std::move(inVec)));
- }
+ allTags.insert(allTags.end(),
+ std::make_move_iterator(tags.begin()),
+ std::make_move_iterator(tags.end()));
}
}
- return std::make_unique<ExpressionOr>(expCtx, std::move(orListElems));
+ return makeTagDisjunction(_rewriter->getExpressionContext(), std::move(allTags));
}
return nullptr;
}
diff --git a/src/mongo/db/query/fle/range_predicate.cpp b/src/mongo/db/query/fle/range_predicate.cpp
index 0ebf33790fc..184f4f4baa2 100644
--- a/src/mongo/db/query/fle/range_predicate.cpp
+++ b/src/mongo/db/query/fle/range_predicate.cpp
@@ -35,6 +35,7 @@
#include "mongo/crypto/fle_crypto.h"
#include "mongo/crypto/fle_tags.h"
#include "mongo/db/matcher/expression_leaf.h"
+#include "mongo/db/pipeline/expression.h"
#include "mongo/db/query/fle/encrypted_predicate.h"
namespace mongo::fle {
@@ -42,6 +43,9 @@ namespace mongo::fle {
REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_WITH_FLAG(BETWEEN,
RangePredicate,
gFeatureFlagFLE2Range);
+REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(ExpressionBetween,
+ RangePredicate,
+ gFeatureFlagFLE2Range);
std::vector<PrfBlock> RangePredicate::generateTags(BSONValue payload) const {
auto parsedPayload = parseFindPayload<ParsedFindRangePayload>(payload);
@@ -62,7 +66,9 @@ std::vector<PrfBlock> RangePredicate::generateTags(BSONValue payload) const {
std::unique_ptr<MatchExpression> RangePredicate::rewriteToTagDisjunction(
MatchExpression* expr) const {
- invariant(expr->matchType() == MatchExpression::BETWEEN);
+ tassert(6720900,
+ "Range rewrite should only be called with $between operator.",
+ expr->matchType() == MatchExpression::BETWEEN);
auto betExpr = static_cast<BetweenMatchExpression*>(expr);
auto payload = betExpr->rhs();
@@ -72,9 +78,24 @@ std::unique_ptr<MatchExpression> RangePredicate::rewriteToTagDisjunction(
return makeTagDisjunction(toBSONArray(generateTags(payload)));
}
-// TODO: SERVER-67209 Server-side rewrite for agg expressions with $between.
std::unique_ptr<Expression> RangePredicate::rewriteToTagDisjunction(Expression* expr) const {
- return nullptr;
+ auto betweenExpr = dynamic_cast<ExpressionBetween*>(expr);
+ tassert(6720901, "Range rewrite should only be called with $between operator.", betweenExpr);
+ auto children = betweenExpr->getChildren();
+ uassert(6720902, "$between should have two children.", children.size() == 2);
+
+ auto fieldpath = dynamic_cast<ExpressionFieldPath*>(children[0].get());
+ uassert(6720903, "first argument should be a fieldpath", fieldpath);
+ auto secondArg = dynamic_cast<ExpressionConstant*>(children[1].get());
+ uassert(6720904, "second argument should be a constant", secondArg);
+ auto payload = secondArg->getValue();
+
+ if (!isPayload(payload)) {
+ return nullptr;
+ }
+ auto tags = toValues(generateTags(std::ref(payload)));
+
+ return makeTagDisjunction(_rewriter->getExpressionContext(), std::move(tags));
}
// TODO: SERVER-67267 Rewrite $between to $_internalFleBetween when number of tags exceeds
diff --git a/src/mongo/db/query/fle/range_predicate_test.cpp b/src/mongo/db/query/fle/range_predicate_test.cpp
index bc8bca184f3..17f33b55e52 100644
--- a/src/mongo/db/query/fle/range_predicate_test.cpp
+++ b/src/mongo/db/query/fle/range_predicate_test.cpp
@@ -29,6 +29,9 @@
#include "mongo/crypto/fle_crypto.h"
#include "mongo/db/matcher/expression_leaf.h"
+#include "mongo/db/pipeline/expression.h"
+#include "mongo/db/pipeline/expression_context_for_test.h"
+#include "mongo/db/query/fle/encrypted_predicate.h"
#include "mongo/db/query/fle/encrypted_predicate_test_fixtures.h"
#include "mongo/db/query/fle/range_predicate.h"
#include "mongo/idl/server_parameter_test_util.h"
@@ -51,39 +54,52 @@ public:
}
+ bool payloadValid = true;
+
protected:
bool isPayload(const BSONElement& elt) const override {
- return true;
+ return payloadValid;
}
bool isPayload(const Value& v) const override {
- return true;
+ return payloadValid;
}
std::vector<PrfBlock> generateTags(BSONValue payload) const {
return stdx::visit(
- OverloadedVisitor{
- [&](BSONElement p) {
- auto parsedPayload = p.Obj().firstElement();
- auto fieldName = parsedPayload.fieldNameStringData();
-
- std::vector<BSONElement> range;
- auto payloadAsArray = parsedPayload.Array();
- for (auto&& elt : payloadAsArray) {
- range.push_back(elt);
- }
-
- std::vector<PrfBlock> allTags;
- for (auto i = range[0].Number(); i <= range[1].Number(); i++) {
- ASSERT(_tags.find({fieldName, i}) != _tags.end());
- auto temp = _tags.find({fieldName, i})->second;
- for (auto tag : temp) {
- allTags.push_back(tag);
- }
- }
- return allTags;
- },
- [&](std::reference_wrapper<Value> v) { return std::vector<PrfBlock>{}; }},
+ OverloadedVisitor{[&](BSONElement p) {
+ auto parsedPayload = p.Obj().firstElement();
+ auto fieldName = parsedPayload.fieldNameStringData();
+
+ std::vector<BSONElement> range;
+ auto payloadAsArray = parsedPayload.Array();
+ for (auto&& elt : payloadAsArray) {
+ range.push_back(elt);
+ }
+
+ std::vector<PrfBlock> allTags;
+ for (auto i = range[0].Number(); i <= range[1].Number(); i++) {
+ ASSERT(_tags.find({fieldName, i}) != _tags.end());
+ auto temp = _tags.find({fieldName, i})->second;
+ for (auto tag : temp) {
+ allTags.push_back(tag);
+ }
+ }
+ return allTags;
+ },
+ [&](std::reference_wrapper<Value> v) {
+ if (v.get().isArray()) {
+ auto arr = v.get().getArray();
+ std::vector<PrfBlock> allTags;
+ for (auto& val : arr) {
+ allTags.push_back(PrfBlock(
+ {static_cast<unsigned char>(val.coerceToInt())}));
+ }
+ return allTags;
+ } else {
+ return std::vector<PrfBlock>{};
+ }
+ }},
payload);
}
@@ -99,7 +115,7 @@ protected:
MockRangePredicate _predicate;
};
-TEST_F(RangePredicateRewriteTest, BasicRangeRewrite) {
+TEST_F(RangePredicateRewriteTest, MatchRangeRewrite) {
RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true);
int start = 1;
@@ -126,5 +142,31 @@ TEST_F(RangePredicateRewriteTest, BasicRangeRewrite) {
assertRewriteToTags(_predicate, &inputExpr, toBSONArray(std::move(allTags)));
}
+
+TEST_F(RangePredicateRewriteTest, AggRangeRewrite) {
+ auto input = fromjson(R"({$between: ["$age", {$literal: [1, 2, 3]}]})");
+ auto inputExpr =
+ ExpressionBetween::parseExpression(&_expCtx, input, _expCtx.variablesParseState);
+
+ auto expected = makeTagDisjunction(&_expCtx, toValues({{1}, {2}, {3}}));
+
+ auto actual = _predicate.rewrite(inputExpr.get());
+
+ ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(),
+ expected->serialize(false).getDocument().toBson());
+}
+
+TEST_F(RangePredicateRewriteTest, AggRangeRewriteNoOp) {
+ auto input = fromjson(R"({$between: ["$age", {$literal: [1, 2, 3]}]})");
+ auto inputExpr =
+ ExpressionBetween::parseExpression(&_expCtx, input, _expCtx.variablesParseState);
+
+ auto expected = inputExpr;
+
+ _predicate.payloadValid = false;
+ auto actual = _predicate.rewrite(inputExpr.get());
+ ASSERT(actual == nullptr);
+}
+
}; // namespace
} // namespace mongo::fle