diff options
Diffstat (limited to 'src/mongo/db/query/fle/encrypted_predicate.h')
| -rw-r--r-- | src/mongo/db/query/fle/encrypted_predicate.h | 67 |
1 files changed, 38 insertions, 29 deletions
diff --git a/src/mongo/db/query/fle/encrypted_predicate.h b/src/mongo/db/query/fle/encrypted_predicate.h index 5fcac966dbf..fc78aa28014 100644 --- a/src/mongo/db/query/fle/encrypted_predicate.h +++ b/src/mongo/db/query/fle/encrypted_predicate.h @@ -40,6 +40,7 @@ #include "mongo/db/matcher/expression_leaf.h" #include "mongo/db/pipeline/expression.h" #include "mongo/db/query/fle/query_rewriter_interface.h" +#include "mongo/stdx/unordered_map.h" /** * This file contains an abstract class that describes rewrites on agg Expressions and @@ -194,33 +195,37 @@ private: * are keyed on the dynamic type for the Expression subclass. */ -using ExpressionToRewriteMap = stdx::unordered_map< - std::type_index, - std::function<std::unique_ptr<Expression>(QueryRewriterInterface*, Expression*)>>; +using ExpressionRewriteFunction = + std::function<std::unique_ptr<Expression>(QueryRewriterInterface*, Expression*)>; +using ExpressionToRewriteMap = + stdx::unordered_map<std::type_index, std::vector<ExpressionRewriteFunction>>; extern ExpressionToRewriteMap aggPredicateRewriteMap; -using MatchTypeToRewriteMap = stdx::unordered_map< - MatchExpression::MatchType, - std::function<std::unique_ptr<MatchExpression>(QueryRewriterInterface*, MatchExpression*)>>; +using MatchRewriteFunction = + std::function<std::unique_ptr<MatchExpression>(QueryRewriterInterface*, MatchExpression*)>; +using MatchTypeToRewriteMap = + stdx::unordered_map<MatchExpression::MatchType, std::vector<MatchRewriteFunction>>; extern MatchTypeToRewriteMap matchPredicateRewriteMap; /** * Register an agg rewrite if a condition is true at startup time. */ -#define REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_GUARDED(className, rewriteClass, isEnabledExpr) \ - MONGO_INITIALIZER(encryptedAggPredicateRewriteFor_##className)(InitializerContext*) { \ - \ - invariant(aggPredicateRewriteMap.find(typeid(className)) == aggPredicateRewriteMap.end()); \ - aggPredicateRewriteMap[typeid(className)] = [&](auto* rewriter, auto* expr) { \ - if (isEnabledExpr) { \ - return rewriteClass{rewriter}.rewrite(expr); \ - } else { \ - return std::unique_ptr<Expression>(nullptr); \ - } \ - }; \ - } +#define REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_GUARDED(className, rewriteClass, isEnabledExpr) \ + MONGO_INITIALIZER(encryptedAggPredicateRewriteFor_##className##_##rewriteClass) \ + (InitializerContext*) { \ + if (aggPredicateRewriteMap.find(typeid(className)) == aggPredicateRewriteMap.end()) { \ + aggPredicateRewriteMap[typeid(className)] = std::vector<ExpressionRewriteFunction>(); \ + } \ + aggPredicateRewriteMap[typeid(className)].push_back([](auto* rewriter, auto* expr) { \ + if (isEnabledExpr) { \ + return rewriteClass{rewriter}.rewrite(expr); \ + } else { \ + return std::unique_ptr<Expression>(nullptr); \ + } \ + }); \ + }; /** * Register an agg rewrite unconditionally. @@ -239,17 +244,21 @@ extern MatchTypeToRewriteMap matchPredicateRewriteMap; * Register a MatchExpression rewrite if a condition is true at startup time. */ #define REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_GUARDED(matchType, rewriteClass, isEnabledExpr) \ - MONGO_INITIALIZER(encryptedMatchPredicateRewriteFor_##matchType)(InitializerContext*) { \ - \ - invariant(matchPredicateRewriteMap.find(MatchExpression::matchType) == \ - matchPredicateRewriteMap.end()); \ - matchPredicateRewriteMap[MatchExpression::matchType] = [&](auto* rewriter, auto* expr) { \ - if (isEnabledExpr) { \ - return rewriteClass{rewriter}.rewrite(expr); \ - } else { \ - return std::unique_ptr<MatchExpression>(nullptr); \ - } \ - }; \ + MONGO_INITIALIZER(encryptedMatchPredicateRewriteFor_##matchType##_##rewriteClass) \ + (InitializerContext*) { \ + if (matchPredicateRewriteMap.find(MatchExpression::matchType) == \ + matchPredicateRewriteMap.end()) { \ + matchPredicateRewriteMap[MatchExpression::matchType] = \ + std::vector<MatchRewriteFunction>(); \ + } \ + matchPredicateRewriteMap[MatchExpression::matchType].push_back( \ + [](auto* rewriter, auto* expr) { \ + if (isEnabledExpr) { \ + return rewriteClass{rewriter}.rewrite(expr); \ + } else { \ + return std::unique_ptr<MatchExpression>(nullptr); \ + } \ + }); \ }; /** * Register a MatchExpression rewrite unconditionally. |
