summaryrefslogtreecommitdiff
path: root/src/mongo
diff options
context:
space:
mode:
authorJacob Evans <jacob.evans@10gen.com>2021-02-22 14:56:47 -0500
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2021-02-24 04:47:18 +0000
commit94bbe88ad4126b94900060c68c144906c7c586a6 (patch)
treed7d6b7a0f7eb393bf09e3b87949ed188e68a0912 /src/mongo
parentd32cb4b3d8f475caee2b2458ff65e5aca3530b41 (diff)
downloadmongo-94bbe88ad4126b94900060c68c144906c7c586a6.tar.gz
SERVER-54709 Add flexible agg expression walker
Diffstat (limited to 'src/mongo')
-rw-r--r--src/mongo/db/pipeline/expression_walker.h163
-rw-r--r--src/mongo/db/pipeline/expression_walker_test.cpp50
2 files changed, 191 insertions, 22 deletions
diff --git a/src/mongo/db/pipeline/expression_walker.h b/src/mongo/db/pipeline/expression_walker.h
index 3222d02ceb9..9718a2f0205 100644
--- a/src/mongo/db/pipeline/expression_walker.h
+++ b/src/mongo/db/pipeline/expression_walker.h
@@ -1,5 +1,5 @@
/**
- * Copyright (C) 2019-present MongoDB, Inc.
+ * Copyright (C) 2021-present MongoDB, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the Server Side Public License, version 1,
@@ -29,38 +29,173 @@
#pragma once
-#include <boost/intrusive_ptr.hpp>
+#include <memory>
+#include <type_traits>
+#include <utility>
#include "mongo/db/pipeline/expression.h"
+#include "mongo/stdx/type_traits.h"
namespace mongo::expression_walker {
+// The following types and constexpr values are used to determine if a Walker has a given member
+// function at compile-time.
+
+/**
+ * PreVisit provides the compiler with a type for a preVisit member function.
+ */
+template <typename Walker, typename Arg>
+using PreVisit = decltype(std::declval<Walker>().preVisit(std::declval<Arg>()));
+/**
+ * hasVoidPreVisit is a template variable indicating whether such a void-returning member function
+ * exists for a given Walker type when called on a pointer to our Expression type.
+ */
+template <typename Walker>
+constexpr auto hasVoidPreVisit = stdx::is_detected_exact_v<void, PreVisit, Walker, Expression*>;
+/**
+ * hasPtrPreVisit is a template variable indicating whether such a pointer-returning member
+ * function exists for a given Walker type when called on a pointer to our Expression type.
+ */
+template <typename Walker>
+constexpr auto hasPtrPreVisit =
+ stdx::is_detected_convertible_v<std::unique_ptr<Expression>, PreVisit, Walker, Expression*>;
+
+/**
+ * InVisit provides the compiler with a type for an inVisit member function.
+ */
+template <typename Walker, typename... Args>
+using InVisit = decltype(std::declval<Walker>().inVisit(std::declval<Args>()...));
+/**
+ * hasBasicInVisit is a template variable indicating whether such a member function exists for a
+ * given Walker type when called on a pointer to our Expression type.
+ */
+template <typename Walker>
+constexpr auto hasBasicInVisit = stdx::is_detected_v<InVisit, Walker, Expression*>;
+/**
+ * hasCountingInVisit is a template variable indicating whether such a member function exists for a
+ * given Walker type when called with a count and a pointer to our Expression type.
+ */
+template <typename Walker>
+constexpr auto hasCountingInVisit =
+ stdx::is_detected_v<InVisit, Walker, unsigned long long, Expression*>;
+
/**
- * Provided with a Walker and an Expression, walk() calls each of the following:
+ * PostVisit provides the compiler with a type for a postVisit member function.
+ */
+template <typename Walker, typename Arg>
+using PostVisit = decltype(std::declval<Walker>().postVisit(std::declval<Arg>()));
+/**
+ * hasVoidPostVisit is a template variable indicating whether such a void-returning member function
+ * exists for a given Walker type when called on a pointer to our Expression type.
+ */
+template <typename Walker>
+constexpr auto hasVoidPostVisit = stdx::is_detected_exact_v<void, PostVisit, Walker, Expression*>;
+/**
+ * hasPtrPostVisit is a template variable indicating whether such a pointer-returning member
+ * function exists for a given Walker type when called on a pointer to our Expression type.
+ */
+template <typename Walker>
+constexpr auto hasPtrPostVisit =
+ stdx::is_detected_convertible_v<std::unique_ptr<Expression>, PostVisit, Walker, Expression*>;
+
+/**
+ * hasReturningVisit is a template variable indicating whether there is a pointer-returning member
+ * function (pre or post) that exists for a given Walker type when called on a pointer to our
+ * Expression type.
+ */
+template <typename Walker>
+constexpr auto hasReturningVisit = hasPtrPreVisit<Walker> || hasPtrPostVisit<Walker>;
+
+/**
+ * Provided with a Walker and an Expression, walk() calls each of the following if they exist:
* * walker.preVisit() once before walking to each child.
* * walker.inVisit() between walking to each child. It is called multiple times, once between each
* pair of children. walker.inVisit() is skipped if the Expression has fewer than two children.
* * walker.postVisit() once after walking to each child.
* Each of the Expression's child Expressions is recursively walked and the same three methods are
- * called for it.
+ * called for it. Although each of the methods is individually optional, at least one of them must
+ * exist. preVisit() and postVisit() may return a pointer to an Expression. If either does, walk()
+ * will replace the current Expression with the return value. If no change is needed during a
+ * particular call, preVisit() and postVisit() may return null. walk() returns a unique_ptr
+ * containing a new root node, if it is modified by a value returning preVisit() or postVisit(),
+ * nullptr if it is not modified or void if modification is impossible for the given Walker.
*/
template <typename Walker>
-void walk(Walker* walker, Expression* expression) {
+auto walk(Walker* walker, Expression* expression)
+    -> std::conditional_t<hasReturningVisit<Walker>, std::unique_ptr<Expression>, void> {
+    static_assert(hasVoidPreVisit<Walker> || hasPtrPreVisit<Walker> || hasBasicInVisit<Walker> ||
+                      hasCountingInVisit<Walker> || hasVoidPostVisit<Walker> ||
+                      hasPtrPostVisit<Walker>,
+                  "Walker must have at least one of the following functions: 'preVisit', "
+                  "'inVisit', 'postVisit'.");
+    static_assert(!hasBasicInVisit<Walker> || !hasCountingInVisit<Walker>,
+                  "Walker must include only one signature for inVisit: inVisit(num, expression) "
+                  "or inVisit(expression).");
+    // Calls walk on a child node. Then replaces that node if walk returns a non-null value.
+    auto walkChild = [&](auto&& child) {
+        if constexpr (hasReturningVisit<Walker>) {
+            if (auto newChild = walk(walker, child.get()))
+                child = newChild.release();
+        } else {
+            walk(walker, child.get());
+        }
+    };
+
+    auto newExpr = std::unique_ptr<Expression>{};
    if (expression) {
-        walker->preVisit(expression);
+        if constexpr (hasVoidPreVisit<Walker>) {
+            walker->preVisit(expression);
+        } else if constexpr (hasPtrPreVisit<Walker>) {
+            newExpr = walker->preVisit(expression);
+            expression = newExpr.get() != nullptr ? newExpr.get() : expression;
+        }
        // InVisit needs to be called between every two nodes which requires more complicated
-        // branching logic.
-        auto count = 0ull;
-        for (auto&& child : expression->getChildren()) {
-            if (count != 0ull)
-                walker->inVisit(count, expression);
-            ++count;
-            walk(walker, child.get());
+        // branching logic. InVisit is forbidden from replacing its Expression through the return
+        // value and must return void since it would break our iteration and be confusing to
+        // replace a node while only a portion of its children have been walked.
+        if constexpr (hasBasicInVisit<Walker>) {
+            static_assert(
+                std::is_void_v<InVisit<Walker, Expression*>>,
+                "Walker::inVisit must return void. Modification is forbidden between walking "
+                "children.");
+            auto skippingFirst = true;
+            for (auto&& child : expression->getChildren()) {
+                if (skippingFirst)
+                    skippingFirst = false;
+                else
+                    walker->inVisit(expression);
+                // Use walkChild(): it handles both void- and pointer-returning walk().
+                walkChild(child);
+            }
+        }
+        // If the signature of InVisit includes a count, maintain it while walking and pass it to
+        // the function.
+        else if constexpr (hasCountingInVisit<Walker>) {
+            static_assert(
+                std::is_void_v<InVisit<Walker, unsigned long long, Expression*>>,
+                "Walker::inVisit must return void. Modification is forbidden between walking "
+                "children.");
+            auto count = 0ull;
+            for (auto&& child : expression->getChildren()) {
+                if (count != 0ull)
+                    walker->inVisit(count, expression);
+                count++;
+                walkChild(child);
+            }
+        } else {
+            for (auto&& child : expression->getChildren())
+                walkChild(child);
        }
-        walker->postVisit(expression);
+        if constexpr (hasVoidPostVisit<Walker>)
+            walker->postVisit(expression);
+        else if constexpr (hasPtrPostVisit<Walker>)
+            if (auto postResult = walker->postVisit(expression))
+                newExpr = std::move(postResult);
    }
+    if constexpr (hasReturningVisit<Walker>)
+        return newExpr;
}
} // namespace mongo::expression_walker
diff --git a/src/mongo/db/pipeline/expression_walker_test.cpp b/src/mongo/db/pipeline/expression_walker_test.cpp
index 447ae1e3c9a..c3a59dd82e4 100644
--- a/src/mongo/db/pipeline/expression_walker_test.cpp
+++ b/src/mongo/db/pipeline/expression_walker_test.cpp
@@ -29,11 +29,14 @@
#include "mongo/platform/basic.h"
+#include <memory>
#include <string>
+#include <type_traits>
#include <vector>
#include "mongo/base/string_data.h"
#include "mongo/bson/json.h"
+#include "mongo/db/exec/document_value/document_value_test_util.h"
#include "mongo/db/pipeline/aggregate_command_gen.h"
#include "mongo/db/pipeline/aggregation_context_fixture.h"
#include "mongo/db/pipeline/aggregation_request_helper.h"
@@ -52,9 +55,9 @@ protected:
ASSERT_EQUALS(inputBson["pipeline"].type(), BSONType::Array);
auto rawPipeline = parsePipelineFromBSON(inputBson["pipeline"]);
NamespaceString testNss("test", "collection");
- AggregateCommand request(testNss, rawPipeline);
+ auto command = AggregateCommand{testNss, rawPipeline};
- return Pipeline::parse(request.getPipeline(), getExpCtx());
+ return Pipeline::parse(command.getPipeline(), getExpCtx());
}
auto parseExpression(std::string expressionString) {
@@ -66,13 +69,12 @@ protected:
using namespace std::string_literals;
using namespace expression_walker;
-TEST_F(ExpressionWalkerTest, NullTreeWalkSucceeds) {
+TEST_F(ExpressionWalkerTest, NothingTreeWalkSucceedsAndReturnsVoid) {
struct {
- void preVisit(Expression*) {}
- void inVisit(unsigned long long, Expression*) {}
void postVisit(Expression*) {}
} nothingWalker;
- auto expression = boost::intrusive_ptr<Expression>();
+ auto expression = std::unique_ptr<Expression>{};
+ static_assert(std::is_same_v<decltype(walk(&nothingWalker, expression.get())), void>);
walk(&nothingWalker, expression.get());
}
@@ -99,15 +101,47 @@ TEST_F(ExpressionWalkerTest, PrintWalkReflectsMutation) {
auto expression = parseExpression(expressionString);
walk(&stringWalker, expression.get());
ASSERT_EQ(stringWalker.string, expressionString);
+
+ struct {
+ auto preVisit(Expression* expression) {
+ if (auto constant = dynamic_cast<ExpressionConstant*>(expression))
+ if (constant->getValue().getString() == "black")
+ return std::make_unique<ExpressionConstant>(expCtx, Value{"white"s});
+ return std::unique_ptr<ExpressionConstant>{};
+ }
+ ExpressionContext* const expCtx;
+ } whiteWalker{getExpCtxRaw()};
+
+ ASSERT_FALSE(walk(&whiteWalker, expression.get()));
+ stringWalker.string.clear();
+ walk(&stringWalker, expression.get());
+ ASSERT_EQ(stringWalker.string, "{$concat: [\"white\", \"green\", \"yellow\"]}"s);
+}
+
+TEST_F(ExpressionWalkerTest, RootNodeReplacable) {
+ struct {
+ auto postVisit(Expression* expression) {
+ return std::make_unique<ExpressionConstant>(expCtx, Value{"soup"s});
+ }
+ ExpressionContext* const expCtx;
+ } replaceWithSoup{getExpCtxRaw()};
+
+ auto expressionString = "{$add: [2, 3, 4, {$atan2: [1, 0]}]}"s;
+ auto expression = parseExpression(expressionString);
+ auto resultExpression = walk(&replaceWithSoup, expression.get());
+ ASSERT_VALUE_EQ(dynamic_cast<ExpressionConstant*>(resultExpression.get())->getValue(),
+ Value{"soup"s});
+ // The input Expression, as a side effect, will have all its branches changed to soup by this
+ // rewrite.
+ for (auto&& child : dynamic_cast<ExpressionAdd*>(expression.get())->getChildren())
+ ASSERT_VALUE_EQ(dynamic_cast<ExpressionConstant*>(child.get())->getValue(), Value{"soup"s});
}
TEST_F(ExpressionWalkerTest, InVisitCanCount) {
struct {
- void preVisit(Expression*) {}
void inVisit(unsigned long long count, Expression*) {
counter.push_back(count);
}
- void postVisit(Expression*) {}
std::vector<unsigned long long> counter;
} countWalker;