From 4a7131637aa7f0d7132e35938b51a58b5d3529d1 Mon Sep 17 00:00:00 2001 From: Svilen Mihaylov Date: Wed, 9 Sep 2020 17:01:28 -0400 Subject: clone --- src/mongo/db/query/optimizer/node.cpp | 111 +++++++++++++++++++++--- src/mongo/db/query/optimizer/node.h | 22 +++-- src/mongo/db/query/optimizer/optimizer_test.cpp | 3 + 3 files changed, 116 insertions(+), 20 deletions(-) diff --git a/src/mongo/db/query/optimizer/node.cpp b/src/mongo/db/query/optimizer/node.cpp index d101f60b018..4836dcce39e 100644 --- a/src/mongo/db/query/optimizer/node.cpp +++ b/src/mongo/db/query/optimizer/node.cpp @@ -27,8 +27,12 @@ * it in the license file. */ +#include +#include + #include "mongo/db/query/optimizer/node.h" #include "mongo/db/query/optimizer/visitor.h" +#include "mongo/util/assert_util.h" namespace mongo::optimizer { @@ -38,7 +42,7 @@ Node::Node(Context& ctx, NodePtr child) : _nodeId(ctx.getNextNodeId()) { _children.push_back(std::move(child)); } -Node::Node(Context& ctx, std::vector children) +Node::Node(Context& ctx, ChildVector children) : _nodeId(ctx.getNextNodeId()), _children(std::move(children)) {} void Node::generateMemoBase(std::ostringstream& os) const { @@ -52,41 +56,102 @@ void Node::visitPreOrder(AbstractVisitor& visitor) const { } } +void Node::visitPostOrder(AbstractVisitor& visitor) const { + for (const NodePtr& ptr : _children) { + ptr->visitPostOrder(visitor); + } + visit(visitor); +} + std::string Node::generateMemo() const { class MemoVisitor : public AbstractVisitor { protected: void visit(const ScanNode& node) override { - node.generateMemo(os); + node.generateMemo(_os); } void visit(const MultiJoinNode& node) override { - node.generateMemo(os); + node.generateMemo(_os); } void visit(const UnionNode& node) override { - node.generateMemo(os); + node.generateMemo(_os); } void visit(const GroupByNode& node) override { - node.generateMemo(os); + node.generateMemo(_os); } void visit(const UnwindNode& node) override { - node.generateMemo(os); + node.generateMemo(_os); } void visit(const WindNode& node) override { - node.generateMemo(os); + node.generateMemo(_os); } public: - std::ostringstream os; + std::ostringstream _os; }; MemoVisitor visitor; visitPreOrder(visitor); - return visitor.os.str(); + return visitor._os.str(); +} + +NodePtr Node::clone(Context& ctx) const { + class CloneVisitor : public AbstractVisitor { + public: + explicit CloneVisitor(Context& ctx) : _ctx(ctx), _childStack() {} + + protected: + void visit(const ScanNode& node) override { + doClone(node, [&](ChildVector v){ return ScanNode::clone(_ctx, node); }); + } + void visit(const MultiJoinNode& node) override { + doClone(node, [&](ChildVector v){ return MultiJoinNode::clone(_ctx, node, std::move(v)); }); + } + void visit(const UnionNode& node) override { + doClone(node, [&](ChildVector v){ return UnionNode::clone(_ctx, node, std::move(v)); }); + } + void visit(const GroupByNode& node) override { + doClone(node, [&](ChildVector v){ return GroupByNode::clone(_ctx, node, std::move(v.at(0))); }); + } + void visit(const UnwindNode& node) override { + doClone(node, [&](ChildVector v){ return UnwindNode::clone(_ctx, node, std::move(v.at(0))); }); + } + void visit(const WindNode& node) override { + doClone(node, [&](ChildVector v){ return WindNode::clone(_ctx, node, std::move(v.at(0))); }); + } + + private: + void doClone(const Node& node, const std::function& cloneFn) { + ChildVector newChildren; + for (int i = 0; i < node.getChildCount(); i++) { + newChildren.push_back(std::move(_childStack.top())); + _childStack.pop(); + } + _childStack.push(cloneFn(std::move(newChildren))); + } + + public: + Context& _ctx; + std::stack _childStack; + }; + + CloneVisitor visitor(ctx); + visitPostOrder(visitor); + invariant(visitor._childStack.size() == 1); + return std::move(visitor._childStack.top()); +} + +int Node::getChildCount() const { + return _children.size(); } NodePtr ScanNode::create(Context& ctx, CollectionNameType collectionName) { return NodePtr(new ScanNode(ctx, std::move(collectionName))); } +NodePtr ScanNode::clone(Context& ctx, const ScanNode& other) { + return create(ctx, other._collectionName); +} + ScanNode::ScanNode(Context& ctx, CollectionNameType collectionName) : Node(ctx), _collectionName(std::move(collectionName)) {} @@ -103,15 +168,19 @@ void ScanNode::visit(AbstractVisitor& visitor) const { NodePtr MultiJoinNode::create(Context& ctx, FilterSet filterSet, ProjectionMap projectionMap, - std::vector children) { + ChildVector children) { return NodePtr(new MultiJoinNode( ctx, std::move(filterSet), std::move(projectionMap), std::move(children))); } +NodePtr MultiJoinNode::clone(Context& ctx, const MultiJoinNode& other, ChildVector newChildren) { + return create(ctx, other._filterSet, other._projectionMap, std::move(newChildren)); +} + MultiJoinNode::MultiJoinNode(Context& ctx, FilterSet filterSet, ProjectionMap projectionMap, - std::vector children) + ChildVector children) : Node(ctx, std::move(children)), _filterSet(std::move(filterSet)), _projectionMap(std::move(projectionMap)) {} @@ -126,11 +195,15 @@ void MultiJoinNode::visit(AbstractVisitor& visitor) const { visitor.visit(*this); } -NodePtr UnionNode::create(Context& ctx, std::vector children) { +NodePtr UnionNode::create(Context& ctx, ChildVector children) { return NodePtr(new UnionNode(ctx, std::move(children))); } -UnionNode::UnionNode(Context& ctx, std::vector children) +NodePtr UnionNode::clone(Context& ctx, const UnionNode& other, ChildVector newChildren) { + return create(ctx, std::move(newChildren)); +} + +UnionNode::UnionNode(Context& ctx, ChildVector children) : Node(ctx, std::move(children)) {} void UnionNode::generateMemo(std::ostringstream& os) const { @@ -151,6 +224,10 @@ NodePtr GroupByNode::create(Context& ctx, new GroupByNode(ctx, std::move(groupByVector), std::move(projectionMap), std::move(child))); } +NodePtr GroupByNode::clone(Context& ctx, const GroupByNode& other, NodePtr newChild) { + return create(ctx, other._groupByVector, other._projectionMap, std::move(newChild)); +} + GroupByNode::GroupByNode(Context& ctx, GroupByNode::GroupByVector groupByVector, GroupByNode::ProjectionMap projectionMap, @@ -177,6 +254,10 @@ NodePtr UnwindNode::create(Context& ctx, new UnwindNode(ctx, std::move(projectionName), retainNonArrays, std::move(child))); } +NodePtr UnwindNode::clone(Context& ctx, const UnwindNode& other, NodePtr newChild) { + return create(ctx, other._projectionName, other._retainNonArrays, std::move(newChild)); +} + UnwindNode::UnwindNode(Context& ctx, ProjectionName projectionName, const bool retainNonArrays, @@ -199,6 +280,10 @@ NodePtr WindNode::create(Context& ctx, ProjectionName projectionName, NodePtr ch return NodePtr(new WindNode(ctx, std::move(projectionName), std::move(child))); } +NodePtr WindNode::clone(Context& ctx, const WindNode& other, NodePtr newChild) { + return create(ctx, other._projectionName, std::move(newChild)); +} + WindNode::WindNode(Context& ctx, ProjectionName projectionName, NodePtr child) : Node(ctx, std::move(child)), _projectionName(std::move(projectionName)) {} diff --git a/src/mongo/db/query/optimizer/node.h b/src/mongo/db/query/optimizer/node.h index fba45edab17..78010d7d333 100644 --- a/src/mongo/db/query/optimizer/node.h +++ b/src/mongo/db/query/optimizer/node.h @@ -57,13 +57,13 @@ public: protected: explicit Node(Context& ctx); explicit Node(Context& ctx, NodePtr child); - explicit Node(Context& ctx, std::vector children); + explicit Node(Context& ctx, ChildVector children); void generateMemoBase(std::ostringstream& os) const; virtual void visit(AbstractVisitor& visitor) const = 0; - void visitPreOrder(AbstractVisitor& visitor) const; + void visitPostOrder(AbstractVisitor& visitor) const; // clone public: @@ -71,7 +71,9 @@ public: std::string generateMemo() const; - //NodePtr clone(Context& ctx) const; + NodePtr clone(Context& ctx) const; + + int getChildCount() const; private: const NodeIdType _nodeId; @@ -81,6 +83,7 @@ private: class ScanNode : public Node { public: static NodePtr create(Context& ctx, CollectionNameType collectionName); + static NodePtr clone(Context& ctx, const ScanNode& other); void generateMemo(std::ostringstream& os) const; @@ -101,7 +104,8 @@ public: static NodePtr create(Context& ctx, FilterSet filterSet, ProjectionMap projectionMap, - std::vector children); + ChildVector children); + static NodePtr clone(Context& ctx, const MultiJoinNode& other, ChildVector newChildren); void generateMemo(std::ostringstream& os) const; @@ -112,7 +116,7 @@ private: explicit MultiJoinNode(Context& ctx, FilterSet filterSet, ProjectionMap projectionMap, - std::vector children); + ChildVector children); FilterSet _filterSet; ProjectionMap _projectionMap; @@ -120,7 +124,8 @@ private: class UnionNode : public Node { public: - static NodePtr create(Context& ctx, std::vector children); + static NodePtr create(Context& ctx, ChildVector children); + static NodePtr clone(Context& ctx, const UnionNode& other, ChildVector newChildren); void generateMemo(std::ostringstream& os) const; @@ -128,7 +133,7 @@ protected: void visit(AbstractVisitor& visitor) const override; private: - explicit UnionNode(Context& ctx, std::vector children); + explicit UnionNode(Context& ctx, ChildVector children); }; class GroupByNode : public Node { @@ -140,6 +145,7 @@ public: GroupByVector groupByVector, ProjectionMap projectionMap, NodePtr child); + static NodePtr clone(Context& ctx, const GroupByNode& other, NodePtr newChild); void generateMemo(std::ostringstream& os) const; @@ -162,6 +168,7 @@ public: ProjectionName projectionName, bool retainNonArrays, NodePtr child); + static NodePtr clone(Context& ctx, const UnwindNode& other, NodePtr newChild); void generateMemo(std::ostringstream& os) const; @@ -178,6 +185,7 @@ private: class WindNode : public Node { public: static NodePtr create(Context& ctx, ProjectionName projectionName, NodePtr child); + static NodePtr clone(Context& ctx, const WindNode& other, NodePtr newChild); void generateMemo(std::ostringstream& os) const; diff --git a/src/mongo/db/query/optimizer/optimizer_test.cpp b/src/mongo/db/query/optimizer/optimizer_test.cpp index f7eb881c531..86966e05a7e 100644 --- a/src/mongo/db/query/optimizer/optimizer_test.cpp +++ b/src/mongo/db/query/optimizer/optimizer_test.cpp @@ -41,6 +41,9 @@ TEST(Optimizer, Basic) { v.push_back(std::move(ptrScan)); NodePtr ptrJoin = MultiJoinNode::create(ctx, {}, {}, std::move(v)); ASSERT_EQ("NodeId: 1\nMultiJoin\nNodeId: 0\nScan\n", ptrJoin->generateMemo()); + + NodePtr cloned = ptrJoin->clone(ctx); + ASSERT_EQ("NodeId: 3\nMultiJoin\nNodeId: 2\nScan\n", cloned->generateMemo()); } } // namespace -- cgit v1.2.1