diff options
Diffstat (limited to 'src/mongo/db/query/fle/range_predicate.cpp')
-rw-r--r-- | src/mongo/db/query/fle/range_predicate.cpp | 128 |
1 files changed, 86 insertions, 42 deletions
diff --git a/src/mongo/db/query/fle/range_predicate.cpp b/src/mongo/db/query/fle/range_predicate.cpp index 8c083b07dc1..dcd71777782 100644 --- a/src/mongo/db/query/fle/range_predicate.cpp +++ b/src/mongo/db/query/fle/range_predicate.cpp @@ -55,6 +55,77 @@ REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(ExpressionBetween, RangePredicate, gFeatureFlagFLE2Range); +REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(ExpressionCompare, + RangePredicate, + gFeatureFlagFLE2Range); + +namespace { +// Validate the range operator passed in and return the fieldpath and payload for the rewrite. If +// the passed-in expression is a comparison with $eq, $ne, or $cmp, none of which represent a range +// predicate, then return null to the caller so that the rewrite can return null. +std::pair<boost::intrusive_ptr<Expression>, Value> validateRangeOp(Expression* expr) { + auto children = [&]() { + if (auto betweenExpr = dynamic_cast<ExpressionBetween*>(expr)) { + return betweenExpr->getChildren(); + } else { + auto cmpExpr = dynamic_cast<ExpressionCompare*>(expr); + tassert(6720901, + "Range rewrite should only be called with $between or comparison operator.", + cmpExpr); + switch (cmpExpr->getOp()) { + case ExpressionCompare::GT: + case ExpressionCompare::GTE: + case ExpressionCompare::LT: + case ExpressionCompare::LTE: + return cmpExpr->getChildren(); + + case ExpressionCompare::EQ: + case ExpressionCompare::NE: + case ExpressionCompare::CMP: + return std::vector<boost::intrusive_ptr<Expression>>(); + } + } + return std::vector<boost::intrusive_ptr<Expression>>(); + }(); + if (children.empty()) { + return {nullptr, Value()}; + } + // Both ExpressionBetween and ExpressionCompare have a fixed arity of 2. + auto fieldpath = dynamic_cast<ExpressionFieldPath*>(children[0].get()); + uassert(6720903, "first argument should be a fieldpath", fieldpath); + auto secondArg = dynamic_cast<ExpressionConstant*>(children[1].get()); + uassert(6720904, "second argument should be a constant", secondArg); + auto payload = secondArg->getValue(); + return {children[0], payload}; +} +} // namespace + +std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload( + StringData path, ParsedFindRangePayload payload) const { + auto* expCtx = _rewriter->getExpressionContext(); + return fleBetweenFromPayload(ExpressionFieldPath::createPathFromString( + expCtx, path.toString(), expCtx->variablesParseState), + payload); +} + +std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload( + boost::intrusive_ptr<Expression> fieldpath, ParsedFindRangePayload payload) const { + tassert(7030501, + "$internalFleBetween can only be generated from a non-stub payload.", + !payload.isStub()); + auto cm = payload.maxCounter; + ServerDataEncryptionLevel1Token serverToken = std::move(payload.serverToken); + std::vector<ConstDataRange> edcTokens; + std::transform(std::make_move_iterator(payload.edges.value().begin()), + std::make_move_iterator(payload.edges.value().end()), + std::back_inserter(edcTokens), + [](FLEFindEdgeTokenSet&& edge) { return edge.edc.toCDR(); }); + + auto* expCtx = _rewriter->getExpressionContext(); + return std::make_unique<ExpressionInternalFLEBetween>( + expCtx, fieldpath, serverToken.toCDR(), cm, std::move(edcTokens)); +} + std::vector<PrfBlock> RangePredicate::generateTags(BSONValue payload) const { auto parsedPayload = parseFindPayload<ParsedFindRangePayload>(payload); std::vector<PrfBlock> tags; @@ -99,56 +170,23 @@ std::unique_ptr<MatchExpression> RangePredicate::rewriteToTagDisjunction( return makeTagDisjunction(toBSONArray(generateTags(payload))); } -std::pair<boost::intrusive_ptr<Expression>, Value> validateBetween(Expression* expr) { - auto betweenExpr = dynamic_cast<ExpressionBetween*>(expr); - tassert(6720901, "Range rewrite should only be called with $between operator.", betweenExpr); - auto children = betweenExpr->getChildren(); - uassert(6720902, "$between should have two children.", children.size() == 2); - - auto fieldpath = dynamic_cast<ExpressionFieldPath*>(children[0].get()); - uassert(6720903, "first argument should be a fieldpath", fieldpath); - auto secondArg = dynamic_cast<ExpressionConstant*>(children[1].get()); - uassert(6720904, "second argument should be a constant", secondArg); - auto payload = secondArg->getValue(); - return {children[0], payload}; -} - std::unique_ptr<Expression> RangePredicate::rewriteToTagDisjunction(Expression* expr) const { - auto [_, payload] = validateBetween(expr); + auto [fieldpath, payload] = validateRangeOp(expr); + if (!fieldpath) { + return nullptr; + } if (!isPayload(payload)) { return nullptr; } + if (isStub(std::ref(payload))) { + return std::make_unique<ExpressionConstant>(_rewriter->getExpressionContext(), Value(true)); + } + auto tags = toValues(generateTags(std::ref(payload))); return makeTagDisjunction(_rewriter->getExpressionContext(), std::move(tags)); } -std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload( - StringData path, ParsedFindRangePayload payload) const { - auto* expCtx = _rewriter->getExpressionContext(); - return fleBetweenFromPayload(ExpressionFieldPath::createPathFromString( - expCtx, path.toString(), expCtx->variablesParseState), - payload); -} - -std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload( - boost::intrusive_ptr<Expression> fieldpath, ParsedFindRangePayload payload) const { - tassert(7030501, - "$internalFleBetween can only be generated from a non-stub payload.", - !payload.isStub()); - auto cm = payload.maxCounter; - ServerDataEncryptionLevel1Token serverToken = std::move(payload.serverToken); - std::vector<ConstDataRange> edcTokens; - std::transform(std::make_move_iterator(payload.edges.value().begin()), - std::make_move_iterator(payload.edges.value().end()), - std::back_inserter(edcTokens), - [](FLEFindEdgeTokenSet&& edge) { return edge.edc.toCDR(); }); - - auto* expCtx = _rewriter->getExpressionContext(); - return std::make_unique<ExpressionInternalFLEBetween>( - expCtx, fieldpath, serverToken.toCDR(), cm, std::move(edcTokens)); -} - std::unique_ptr<MatchExpression> RangePredicate::rewriteToRuntimeComparison( MatchExpression* expr) const { BSONElement ffp; @@ -179,11 +217,17 @@ std::unique_ptr<MatchExpression> RangePredicate::rewriteToRuntimeComparison( } std::unique_ptr<Expression> RangePredicate::rewriteToRuntimeComparison(Expression* expr) const { - auto [fieldpath, ffp] = validateBetween(expr); + auto [fieldpath, ffp] = validateRangeOp(expr); + if (!fieldpath) { + return nullptr; + } if (!isPayload(ffp)) { return nullptr; } auto payload = parseFindPayload<ParsedFindRangePayload>(ffp); + if (payload.isStub()) { + return std::make_unique<ExpressionConstant>(_rewriter->getExpressionContext(), Value(true)); + } return fleBetweenFromPayload(fieldpath, payload); } } // namespace mongo::fle |