summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSvilen Mihaylov <svilen.mihaylov@mongodb.com>2020-09-09 17:01:28 -0400
committerSvilen Mihaylov <svilen.mihaylov@mongodb.com>2020-09-11 10:11:33 -0400
commit6dee6f0a7b62faaad5a43a4564a76e4ab99b0012 (patch)
tree5349f999fd86f7fd6445f48b1a5b57ca3489ff36
parentb7b7cd9f449a5eecddde3120c57425652531e3d8 (diff)
downloadmongo-6dee6f0a7b62faaad5a43a4564a76e4ab99b0012.tar.gz
clone
-rw-r--r--src/mongo/db/query/optimizer/node.cpp111
-rw-r--r--src/mongo/db/query/optimizer/node.h22
-rw-r--r--src/mongo/db/query/optimizer/optimizer_test.cpp3
3 files changed, 116 insertions, 20 deletions
diff --git a/src/mongo/db/query/optimizer/node.cpp b/src/mongo/db/query/optimizer/node.cpp
index d101f60b018..4836dcce39e 100644
--- a/src/mongo/db/query/optimizer/node.cpp
+++ b/src/mongo/db/query/optimizer/node.cpp
@@ -27,8 +27,12 @@
* it in the license file.
*/
+#include <functional>
+#include <stack>
+
#include "mongo/db/query/optimizer/node.h"
#include "mongo/db/query/optimizer/visitor.h"
+#include "mongo/util/assert_util.h"
namespace mongo::optimizer {
@@ -38,7 +42,7 @@ Node::Node(Context& ctx, NodePtr child) : _nodeId(ctx.getNextNodeId()) {
_children.push_back(std::move(child));
}
-Node::Node(Context& ctx, std::vector<NodePtr> children)
+Node::Node(Context& ctx, ChildVector children)
: _nodeId(ctx.getNextNodeId()), _children(std::move(children)) {}
void Node::generateMemoBase(std::ostringstream& os) const {
@@ -52,41 +56,102 @@ void Node::visitPreOrder(AbstractVisitor& visitor) const {
}
}
+void Node::visitPostOrder(AbstractVisitor& visitor) const {
+ for (const NodePtr& ptr : _children) {
+ ptr->visitPostOrder(visitor);
+ }
+ visit(visitor);
+}
+
std::string Node::generateMemo() const {
class MemoVisitor : public AbstractVisitor {
protected:
void visit(const ScanNode& node) override {
- node.generateMemo(os);
+ node.generateMemo(_os);
}
void visit(const MultiJoinNode& node) override {
- node.generateMemo(os);
+ node.generateMemo(_os);
}
void visit(const UnionNode& node) override {
- node.generateMemo(os);
+ node.generateMemo(_os);
}
void visit(const GroupByNode& node) override {
- node.generateMemo(os);
+ node.generateMemo(_os);
}
void visit(const UnwindNode& node) override {
- node.generateMemo(os);
+ node.generateMemo(_os);
}
void visit(const WindNode& node) override {
- node.generateMemo(os);
+ node.generateMemo(_os);
}
public:
- std::ostringstream os;
+ std::ostringstream _os;
};
MemoVisitor visitor;
visitPreOrder(visitor);
- return visitor.os.str();
+ return visitor._os.str();
+}
+
+NodePtr Node::clone(Context& ctx) const {
+ class CloneVisitor : public AbstractVisitor {
+ public:
+ explicit CloneVisitor(Context& ctx) : _ctx(ctx), _childStack() {}
+
+ protected:
+ void visit(const ScanNode& node) override {
+ doClone(node, [&](ChildVector v){ return ScanNode::clone(_ctx, node); });
+ }
+ void visit(const MultiJoinNode& node) override {
+ doClone(node, [&](ChildVector v){ return MultiJoinNode::clone(_ctx, node, std::move(v)); });
+ }
+ void visit(const UnionNode& node) override {
+ doClone(node, [&](ChildVector v){ return UnionNode::clone(_ctx, node, std::move(v)); });
+ }
+ void visit(const GroupByNode& node) override {
+ doClone(node, [&](ChildVector v){ return GroupByNode::clone(_ctx, node, std::move(v.at(0))); });
+ }
+ void visit(const UnwindNode& node) override {
+ doClone(node, [&](ChildVector v){ return UnwindNode::clone(_ctx, node, std::move(v.at(0))); });
+ }
+ void visit(const WindNode& node) override {
+ doClone(node, [&](ChildVector v){ return WindNode::clone(_ctx, node, std::move(v.at(0))); });
+ }
+
+ private:
+ void doClone(const Node& node, const std::function<NodePtr(ChildVector newChildren)>& cloneFn) {
+ ChildVector newChildren;
+ for (int i = 0; i < node.getChildCount(); i++) {
+ newChildren.push_back(std::move(_childStack.top()));
+ _childStack.pop();
+ }
+ _childStack.push(cloneFn(std::move(newChildren)));
+ }
+
+ public:
+ Context& _ctx;
+ std::stack<NodePtr> _childStack;
+ };
+
+ CloneVisitor visitor(ctx);
+ visitPostOrder(visitor);
+ invariant(visitor._childStack.size() == 1);
+ return std::move(visitor._childStack.top());
+}
+
+int Node::getChildCount() const {
+ return _children.size();
}
NodePtr ScanNode::create(Context& ctx, CollectionNameType collectionName) {
return NodePtr(new ScanNode(ctx, std::move(collectionName)));
}
+NodePtr ScanNode::clone(Context& ctx, const ScanNode& other) {
+ return create(ctx, other._collectionName);
+}
+
ScanNode::ScanNode(Context& ctx, CollectionNameType collectionName)
: Node(ctx), _collectionName(std::move(collectionName)) {}
@@ -103,15 +168,19 @@ void ScanNode::visit(AbstractVisitor& visitor) const {
NodePtr MultiJoinNode::create(Context& ctx,
FilterSet filterSet,
ProjectionMap projectionMap,
- std::vector<NodePtr> children) {
+ ChildVector children) {
return NodePtr(new MultiJoinNode(
ctx, std::move(filterSet), std::move(projectionMap), std::move(children)));
}
+NodePtr MultiJoinNode::clone(Context& ctx, const MultiJoinNode& other, ChildVector newChildren) {
+ return create(ctx, other._filterSet, other._projectionMap, std::move(newChildren));
+}
+
MultiJoinNode::MultiJoinNode(Context& ctx,
FilterSet filterSet,
ProjectionMap projectionMap,
- std::vector<NodePtr> children)
+ ChildVector children)
: Node(ctx, std::move(children)),
_filterSet(std::move(filterSet)),
_projectionMap(std::move(projectionMap)) {}
@@ -126,11 +195,15 @@ void MultiJoinNode::visit(AbstractVisitor& visitor) const {
visitor.visit(*this);
}
-NodePtr UnionNode::create(Context& ctx, std::vector<NodePtr> children) {
+NodePtr UnionNode::create(Context& ctx, ChildVector children) {
return NodePtr(new UnionNode(ctx, std::move(children)));
}
-UnionNode::UnionNode(Context& ctx, std::vector<NodePtr> children)
+NodePtr UnionNode::clone(Context& ctx, const UnionNode& other, ChildVector newChildren) {
+ return create(ctx, std::move(newChildren));
+}
+
+UnionNode::UnionNode(Context& ctx, ChildVector children)
: Node(ctx, std::move(children)) {}
void UnionNode::generateMemo(std::ostringstream& os) const {
@@ -151,6 +224,10 @@ NodePtr GroupByNode::create(Context& ctx,
new GroupByNode(ctx, std::move(groupByVector), std::move(projectionMap), std::move(child)));
}
+NodePtr GroupByNode::clone(Context& ctx, const GroupByNode& other, NodePtr newChild) {
+ return create(ctx, other._groupByVector, other._projectionMap, std::move(newChild));
+}
+
GroupByNode::GroupByNode(Context& ctx,
GroupByNode::GroupByVector groupByVector,
GroupByNode::ProjectionMap projectionMap,
@@ -177,6 +254,10 @@ NodePtr UnwindNode::create(Context& ctx,
new UnwindNode(ctx, std::move(projectionName), retainNonArrays, std::move(child)));
}
+NodePtr UnwindNode::clone(Context& ctx, const UnwindNode& other, NodePtr newChild) {
+ return create(ctx, other._projectionName, other._retainNonArrays, std::move(newChild));
+}
+
UnwindNode::UnwindNode(Context& ctx,
ProjectionName projectionName,
const bool retainNonArrays,
@@ -199,6 +280,10 @@ NodePtr WindNode::create(Context& ctx, ProjectionName projectionName, NodePtr ch
return NodePtr(new WindNode(ctx, std::move(projectionName), std::move(child)));
}
+NodePtr WindNode::clone(Context& ctx, const WindNode& other, NodePtr newChild) {
+ return create(ctx, other._projectionName, std::move(newChild));
+}
+
WindNode::WindNode(Context& ctx, ProjectionName projectionName, NodePtr child)
: Node(ctx, std::move(child)), _projectionName(std::move(projectionName)) {}
diff --git a/src/mongo/db/query/optimizer/node.h b/src/mongo/db/query/optimizer/node.h
index fba45edab17..78010d7d333 100644
--- a/src/mongo/db/query/optimizer/node.h
+++ b/src/mongo/db/query/optimizer/node.h
@@ -57,13 +57,13 @@ public:
protected:
explicit Node(Context& ctx);
explicit Node(Context& ctx, NodePtr child);
- explicit Node(Context& ctx, std::vector<NodePtr> children);
+ explicit Node(Context& ctx, ChildVector children);
void generateMemoBase(std::ostringstream& os) const;
virtual void visit(AbstractVisitor& visitor) const = 0;
-
void visitPreOrder(AbstractVisitor& visitor) const;
+ void visitPostOrder(AbstractVisitor& visitor) const;
// clone
public:
@@ -71,7 +71,9 @@ public:
std::string generateMemo() const;
- //NodePtr clone(Context& ctx) const;
+ NodePtr clone(Context& ctx) const;
+
+ int getChildCount() const;
private:
const NodeIdType _nodeId;
@@ -81,6 +83,7 @@ private:
class ScanNode : public Node {
public:
static NodePtr create(Context& ctx, CollectionNameType collectionName);
+ static NodePtr clone(Context& ctx, const ScanNode& other);
void generateMemo(std::ostringstream& os) const;
@@ -101,7 +104,8 @@ public:
static NodePtr create(Context& ctx,
FilterSet filterSet,
ProjectionMap projectionMap,
- std::vector<NodePtr> children);
+ ChildVector children);
+ static NodePtr clone(Context& ctx, const MultiJoinNode& other, ChildVector newChildren);
void generateMemo(std::ostringstream& os) const;
@@ -112,7 +116,7 @@ private:
explicit MultiJoinNode(Context& ctx,
FilterSet filterSet,
ProjectionMap projectionMap,
- std::vector<NodePtr> children);
+ ChildVector children);
FilterSet _filterSet;
ProjectionMap _projectionMap;
@@ -120,7 +124,8 @@ private:
class UnionNode : public Node {
public:
- static NodePtr create(Context& ctx, std::vector<NodePtr> children);
+ static NodePtr create(Context& ctx, ChildVector children);
+ static NodePtr clone(Context& ctx, const UnionNode& other, ChildVector newChildren);
void generateMemo(std::ostringstream& os) const;
@@ -128,7 +133,7 @@ protected:
void visit(AbstractVisitor& visitor) const override;
private:
- explicit UnionNode(Context& ctx, std::vector<NodePtr> children);
+ explicit UnionNode(Context& ctx, ChildVector children);
};
class GroupByNode : public Node {
@@ -140,6 +145,7 @@ public:
GroupByVector groupByVector,
ProjectionMap projectionMap,
NodePtr child);
+ static NodePtr clone(Context& ctx, const GroupByNode& other, NodePtr newChild);
void generateMemo(std::ostringstream& os) const;
@@ -162,6 +168,7 @@ public:
ProjectionName projectionName,
bool retainNonArrays,
NodePtr child);
+ static NodePtr clone(Context& ctx, const UnwindNode& other, NodePtr newChild);
void generateMemo(std::ostringstream& os) const;
@@ -178,6 +185,7 @@ private:
class WindNode : public Node {
public:
static NodePtr create(Context& ctx, ProjectionName projectionName, NodePtr child);
+ static NodePtr clone(Context& ctx, const WindNode& other, NodePtr newChild);
void generateMemo(std::ostringstream& os) const;
diff --git a/src/mongo/db/query/optimizer/optimizer_test.cpp b/src/mongo/db/query/optimizer/optimizer_test.cpp
index f7eb881c531..86966e05a7e 100644
--- a/src/mongo/db/query/optimizer/optimizer_test.cpp
+++ b/src/mongo/db/query/optimizer/optimizer_test.cpp
@@ -41,6 +41,9 @@ TEST(Optimizer, Basic) {
v.push_back(std::move(ptrScan));
NodePtr ptrJoin = MultiJoinNode::create(ctx, {}, {}, std::move(v));
ASSERT_EQ("NodeId: 1\nMultiJoin\nNodeId: 0\nScan\n", ptrJoin->generateMemo());
+
+ NodePtr cloned = ptrJoin->clone(ctx);
+ ASSERT_EQ("NodeId: 3\nMultiJoin\nNodeId: 2\nScan\n", cloned->generateMemo());
}
} // namespace