diff options
author | Jacob Evans <jacob.evans@10gen.com> | 2021-02-22 14:56:47 -0500 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2021-02-24 04:47:18 +0000 |
commit | 94bbe88ad4126b94900060c68c144906c7c586a6 (patch) | |
tree | d7d6b7a0f7eb393bf09e3b87949ed188e68a0912 /src/mongo | |
parent | d32cb4b3d8f475caee2b2458ff65e5aca3530b41 (diff) | |
download | mongo-94bbe88ad4126b94900060c68c144906c7c586a6.tar.gz |
SERVER-54709 Add flexible agg expression walker
Diffstat (limited to 'src/mongo')
-rw-r--r-- | src/mongo/db/pipeline/expression_walker.h | 163 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression_walker_test.cpp | 50 |
2 files changed, 191 insertions, 22 deletions
diff --git a/src/mongo/db/pipeline/expression_walker.h b/src/mongo/db/pipeline/expression_walker.h index 3222d02ceb9..9718a2f0205 100644 --- a/src/mongo/db/pipeline/expression_walker.h +++ b/src/mongo/db/pipeline/expression_walker.h @@ -1,5 +1,5 @@ /** - * Copyright (C) 2019-present MongoDB, Inc. + * Copyright (C) 2021-present MongoDB, Inc. * * This program is free software: you can redistribute it and/or modify * it under the terms of the Server Side Public License, version 1, @@ -29,38 +29,173 @@ #pragma once -#include <boost/intrusive_ptr.hpp> +#include <memory> +#include <type_traits> +#include <utility> #include "mongo/db/pipeline/expression.h" +#include "mongo/stdx/type_traits.h" namespace mongo::expression_walker { +// The following types and constexpr values are used to determine if a Walker has a given member +// function at compile-time. + +/** + * PreVisit provides the compiler with a type for a preVisit member function. + */ +template <typename Walker, typename Arg> +using PreVisit = decltype(std::declval<Walker>().preVisit(std::declval<Arg>())); +/** + * hasVoidPreVisit is a template variable indicating whether such a void-returning member function + * exists for a given Walker type when called on a pointer to our Expression type. + */ +template <typename Walker> +constexpr auto hasVoidPreVisit = stdx::is_detected_exact_v<void, PreVisit, Walker, Expression*>; +/** + * hasVoidPreVisit is a template variable indicating whether such a pointer-returning member + * function exists for a given Walker type when called on a pointer to our Expression type. + */ +template <typename Walker> +constexpr auto hasPtrPreVisit = + stdx::is_detected_convertible_v<std::unique_ptr<Expression>, PreVisit, Walker, Expression*>; + +/** + * InVisit provides the compiler with a type for an inVisit member function. + */ +template <typename Walker, typename... 
Args> +using InVisit = decltype(std::declval<Walker>().inVisit(std::declval<Args>()...)); +/** + * hasBasicInVisit is a template variable indicating whether such a member function exists for a + * given Walker type when called on a pointer to our Expression type. + */ +template <typename Walker> +constexpr auto hasBasicInVisit = stdx::is_detected_v<InVisit, Walker, Expression*>; +/** + * hasCountingInVisit is a template variable indicating whether such a member function exists for a + * given Walker type when called on a pointer to our Expression type. + */ +template <typename Walker> +constexpr auto hasCountingInVisit = + stdx::is_detected_v<InVisit, Walker, unsigned long long, Expression*>; + /** - * Provided with a Walker and an Expression, walk() calls each of the following: + * PostVisit provides the compiler with a type for a postVisit member function. + */ +template <typename Walker, typename Arg> +using PostVisit = decltype(std::declval<Walker>().postVisit(std::declval<Arg>())); +/** + * hasVoidPostVisit is a template variable indicating whether such a void-returning member function + * exists for a given Walker type when called on a pointer to our Expression type. + */ +template <typename Walker> +constexpr auto hasVoidPostVisit = stdx::is_detected_exact_v<void, PostVisit, Walker, Expression*>; +/** + * hasVoidPostVisit is a template variable indicating whether such a pointer-returning member + * function exists for a given Walker type when called on a pointer to our Expression type. + */ +template <typename Walker> +constexpr auto hasPtrPostVisit = + stdx::is_detected_convertible_v<std::unique_ptr<Expression>, PostVisit, Walker, Expression*>; + +/** + * hasReturningVisit is a template variable indicating whether there is a pointer-returning member + * function (pre or post) that exists for a given Walker type when called on a pointer to our + * Expression type. 
+ */ +template <typename Walker> +constexpr auto hasReturningVisit = hasPtrPreVisit<Walker> || hasPtrPostVisit<Walker>; + +/** + * Provided with a Walker and an Expression, walk() calls each of the following if they exist: * * walker.preVisit() once before walking to each child. * * walker.inVisit() between walking to each child. It is called multiple times, once between each * pair of children. walker.inVisit() is skipped if the Expression has fewer than two children. * * walker.postVisit() once after walking to each child. * Each of the Expression's child Expressions is recursively walked and the same three methods are - * called for it. + * called for it. Although each of the methods are individually optional, at least one of them must + * exist. preVisit() and postVisit() may return a pointer to an Expression. If either does, walk() + * will replace the current Expression with the return value. If no change is needed during a + * particular call, preVisit() and postVisit() may return null. walk() returns a unique_ptr + * containing a new root node, if it is modified by a value returning preVisit() or postVisit(), + * nullptr if it is not modified or void if modification is impossible for the given Walker. */ template <typename Walker> -void walk(Walker* walker, Expression* expression) { +auto walk(Walker* walker, Expression* expression) + -> std::conditional_t<hasReturningVisit<Walker>, std::unique_ptr<Expression>, void> { + static_assert(hasVoidPreVisit<Walker> || hasPtrPreVisit<Walker> || hasBasicInVisit<Walker> || + hasCountingInVisit<Walker> || hasVoidPostVisit<Walker> || + hasPtrPostVisit<Walker>, + "Walker must have at least one of the following functions: 'preVisit', " + "'inVisit', 'postVisit'."); + static_assert(!hasBasicInVisit<Walker> || !hasCountingInVisit<Walker>, + "Walker must include only one signature for inVisit: inVisit(num, expression) " + "or inVisit(expression)."); + // Calls walk on a child node. 
Then replaces that node if walk returns a non-null value. + auto walkChild = [&](auto&& child) { + if constexpr (hasReturningVisit<Walker>) { + if (auto newChild = walk(walker, child.get())) + child = newChild.release(); + } else { + walk(walker, child.get()); + } + }; + + auto newExpr = std::unique_ptr<Expression>{}; if (expression) { - walker->preVisit(expression); + if constexpr (hasVoidPreVisit<Walker>) { + walker->preVisit(expression); + } else if constexpr (hasPtrPreVisit<Walker>) { + newExpr = walker->preVisit(expression); + expression = newExpr.get() != nullptr ? newExpr.get() : expression; + } // InVisit needs to be called between every two nodes which requires more complicated - // branching logic. - auto count = 0ull; - for (auto&& child : expression->getChildren()) { - if (count != 0ull) - walker->inVisit(count, expression); - ++count; - walk(walker, child.get()); + // branching logic. InVisit is forbidden from replacing its Expression through the return + // value and must return void since it would break our iteration and be confusing to + // replace a node while only a portion of its children have been walked. + if constexpr (hasBasicInVisit<Walker>) { + static_assert( + std::is_void_v<InVisit<Walker, Expression*>>, + "Walker::inVisit must return void. Modification is forbidden between walking " + "children."); + auto skippingFirst = true; + for (auto&& child : expression->getChildren()) { + if (skippingFirst) + skippingFirst = false; + else + walker->inVisit(expression); + if (auto newChild = walk(walker, child.get())) + child = std::move(newChild); + } + } + // If the signature of InVisit includes a count, maintaing it while walking and pass it to + // the function. + else if constexpr (hasCountingInVisit<Walker>) { + static_assert( + std::is_void_v<InVisit<Walker, unsigned long long, Expression*>>, + "Walker::inVisit must return void. 
Modification is forbidden between walking " + "children."); + auto count = 0ull; + for (auto&& child : expression->getChildren()) { + if (count != 0ull) + walker->inVisit(count, expression); + count++; + walkChild(child); + } + } else { + for (auto&& child : expression->getChildren()) + walkChild(child); } - walker->postVisit(expression); + if constexpr (hasVoidPostVisit<Walker>) + walker->postVisit(expression); + else if constexpr (hasPtrPostVisit<Walker>) + if (auto postResult = walker->postVisit(expression)) + newExpr = std::move(postResult); } + if constexpr (hasReturningVisit<Walker>) + return newExpr; } } // namespace mongo::expression_walker diff --git a/src/mongo/db/pipeline/expression_walker_test.cpp b/src/mongo/db/pipeline/expression_walker_test.cpp index 447ae1e3c9a..c3a59dd82e4 100644 --- a/src/mongo/db/pipeline/expression_walker_test.cpp +++ b/src/mongo/db/pipeline/expression_walker_test.cpp @@ -29,11 +29,14 @@ #include "mongo/platform/basic.h" +#include <memory> #include <string> +#include <type_traits> #include <vector> #include "mongo/base/string_data.h" #include "mongo/bson/json.h" +#include "mongo/db/exec/document_value/document_value_test_util.h" #include "mongo/db/pipeline/aggregate_command_gen.h" #include "mongo/db/pipeline/aggregation_context_fixture.h" #include "mongo/db/pipeline/aggregation_request_helper.h" @@ -52,9 +55,9 @@ protected: ASSERT_EQUALS(inputBson["pipeline"].type(), BSONType::Array); auto rawPipeline = parsePipelineFromBSON(inputBson["pipeline"]); NamespaceString testNss("test", "collection"); - AggregateCommand request(testNss, rawPipeline); + auto command = AggregateCommand{testNss, rawPipeline}; - return Pipeline::parse(request.getPipeline(), getExpCtx()); + return Pipeline::parse(command.getPipeline(), getExpCtx()); } auto parseExpression(std::string expressionString) { @@ -66,13 +69,12 @@ protected: using namespace std::string_literals; using namespace expression_walker; -TEST_F(ExpressionWalkerTest, NullTreeWalkSucceeds) 
{ +TEST_F(ExpressionWalkerTest, NothingTreeWalkSucceedsAndReturnsVoid) { struct { - void preVisit(Expression*) {} - void inVisit(unsigned long long, Expression*) {} void postVisit(Expression*) {} } nothingWalker; - auto expression = boost::intrusive_ptr<Expression>(); + auto expression = std::unique_ptr<Expression>{}; + static_assert(std::is_same_v<decltype(walk(&nothingWalker, expression.get())), void>); walk(&nothingWalker, expression.get()); } @@ -99,15 +101,47 @@ TEST_F(ExpressionWalkerTest, PrintWalkReflectsMutation) { auto expression = parseExpression(expressionString); walk(&stringWalker, expression.get()); ASSERT_EQ(stringWalker.string, expressionString); + + struct { + auto preVisit(Expression* expression) { + if (auto constant = dynamic_cast<ExpressionConstant*>(expression)) + if (constant->getValue().getString() == "black") + return std::make_unique<ExpressionConstant>(expCtx, Value{"white"s}); + return std::unique_ptr<ExpressionConstant>{}; + } + ExpressionContext* const expCtx; + } whiteWalker{getExpCtxRaw()}; + + ASSERT_FALSE(walk(&whiteWalker, expression.get())); + stringWalker.string.clear(); + walk(&stringWalker, expression.get()); + ASSERT_EQ(stringWalker.string, "{$concat: [\"white\", \"green\", \"yellow\"]}"s); +} + +TEST_F(ExpressionWalkerTest, RootNodeReplacable) { + struct { + auto postVisit(Expression* expression) { + return std::make_unique<ExpressionConstant>(expCtx, Value{"soup"s}); + } + ExpressionContext* const expCtx; + } replaceWithSoup{getExpCtxRaw()}; + + auto expressionString = "{$add: [2, 3, 4, {$atan2: [1, 0]}]}"s; + auto expression = parseExpression(expressionString); + auto resultExpression = walk(&replaceWithSoup, expression.get()); + ASSERT_VALUE_EQ(dynamic_cast<ExpressionConstant*>(resultExpression.get())->getValue(), + Value{"soup"s}); + // The input Expression, as a side effect, will have all its branches changed to soup by this + // rewrite.
+ for (auto&& child : dynamic_cast<ExpressionAdd*>(expression.get())->getChildren()) + ASSERT_VALUE_EQ(dynamic_cast<ExpressionConstant*>(child.get())->getValue(), Value{"soup"s}); } TEST_F(ExpressionWalkerTest, InVisitCanCount) { struct { - void preVisit(Expression*) {} void inVisit(unsigned long long count, Expression*) { counter.push_back(count); } - void postVisit(Expression*) {} std::vector<unsigned long long> counter; } countWalker; |