diff options
author | Svilen Mihaylov <svilen.mihaylov@mongodb.com> | 2020-09-09 15:48:21 -0400 |
---|---|---|
committer | Svilen Mihaylov <svilen.mihaylov@mongodb.com> | 2020-09-11 10:11:33 -0400 |
commit | b7b7cd9f449a5eecddde3120c57425652531e3d8 (patch) | |
tree | 5848c061998d0bb91c0d30e00ba22a68384bbb27 | |
parent | 07ee2c7f5ae5ad672f560c851d5625e884cdf3d9 (diff) | |
download | mongo-b7b7cd9f449a5eecddde3120c57425652531e3d8.tar.gz |
fix visitor
-rw-r--r-- | src/mongo/db/query/optimizer/node.cpp | 74 | ||||
-rw-r--r-- | src/mongo/db/query/optimizer/node.h | 38 | ||||
-rw-r--r-- | src/mongo/db/query/optimizer/optimizer_test.cpp | 7 |
3 files changed, 66 insertions, 53 deletions
diff --git a/src/mongo/db/query/optimizer/node.cpp b/src/mongo/db/query/optimizer/node.cpp index c52d55c1be9..d101f60b018 100644 --- a/src/mongo/db/query/optimizer/node.cpp +++ b/src/mongo/db/query/optimizer/node.cpp @@ -45,28 +45,35 @@ void Node::generateMemoBase(std::ostringstream& os) const { os << "NodeId: " << _nodeId << "\n"; } -std::string Node::generateMemo() { - class MemoVisitor: public AbstractVisitor - { +void Node::visitPreOrder(AbstractVisitor& visitor) const { + visit(visitor); + for (const NodePtr& ptr : _children) { + ptr->visitPreOrder(visitor); + } +} + +std::string Node::generateMemo() const { + class MemoVisitor : public AbstractVisitor { protected: void visit(const ScanNode& node) override { - node.generateScanMemo(os); + node.generateMemo(os); } void visit(const MultiJoinNode& node) override { - node.generateMultiJoinMemo(os); + node.generateMemo(os); } void visit(const UnionNode& node) override { - node.generateUnionMemo(os); + node.generateMemo(os); } void visit(const GroupByNode& node) override { - node.generateGroupByMemo(os); + node.generateMemo(os); } void visit(const UnwindNode& node) override { - node.generateUnwindMemo(os); + node.generateMemo(os); } void visit(const WindNode& node) override { - node.generateWindMemo(os); + node.generateMemo(os); } + public: std::ostringstream os; }; @@ -76,13 +83,6 @@ std::string Node::generateMemo() { return visitor.os.str(); } -void Node::visitPreOrder(AbstractVisitor& visitor) { - visit(visitor); - for (const NodePtr& ptr: _children) { - ptr->visitPreOrder(visitor); - } -} - NodePtr ScanNode::create(Context& ctx, CollectionNameType collectionName) { return NodePtr(new ScanNode(ctx, std::move(collectionName))); } @@ -90,12 +90,13 @@ NodePtr ScanNode::create(Context& ctx, CollectionNameType collectionName) { ScanNode::ScanNode(Context& ctx, CollectionNameType collectionName) : Node(ctx), _collectionName(std::move(collectionName)) {} -void ScanNode::generateScanMemo(std::ostringstream& os) const { +void ScanNode::generateMemo(std::ostringstream& os) const { Node::generateMemoBase(os); - os << "Scan" << "\n"; + os << "Scan" + << "\n"; } -void ScanNode::visit(AbstractVisitor& visitor) { +void ScanNode::visit(AbstractVisitor& visitor) const { visitor.visit(*this); } @@ -115,12 +116,13 @@ MultiJoinNode::MultiJoinNode(Context& ctx, _filterSet(std::move(filterSet)), _projectionMap(std::move(projectionMap)) {} -void MultiJoinNode::generateMultiJoinMemo(std::ostringstream& os) const { +void MultiJoinNode::generateMemo(std::ostringstream& os) const { Node::generateMemoBase(os); - os << "MultiJoin" << "\n"; + os << "MultiJoin" + << "\n"; } -void MultiJoinNode::visit(AbstractVisitor& visitor) { +void MultiJoinNode::visit(AbstractVisitor& visitor) const { visitor.visit(*this); } @@ -131,12 +133,13 @@ NodePtr UnionNode::create(Context& ctx, std::vector<NodePtr> children) { UnionNode::UnionNode(Context& ctx, std::vector<NodePtr> children) : Node(ctx, std::move(children)) {} -void UnionNode::generateUnionMemo(std::ostringstream& os) const { +void UnionNode::generateMemo(std::ostringstream& os) const { Node::generateMemoBase(os); - os << "Union" << "\n"; + os << "Union" + << "\n"; } -void UnionNode::visit(AbstractVisitor& visitor) { +void UnionNode::visit(AbstractVisitor& visitor) const { visitor.visit(*this); } @@ -156,12 +159,13 @@ GroupByNode::GroupByNode(Context& ctx, _groupByVector(std::move(groupByVector)), _projectionMap(std::move(projectionMap)) {} -void GroupByNode::generateGroupByMemo(std::ostringstream& os) const { +void GroupByNode::generateMemo(std::ostringstream& os) const { Node::generateMemoBase(os); - os << "GroupBy" << "\n"; + os << "GroupBy" + << "\n"; } -void GroupByNode::visit(AbstractVisitor& visitor) { +void GroupByNode::visit(AbstractVisitor& visitor) const { visitor.visit(*this); } @@ -181,12 +185,13 @@ UnwindNode::UnwindNode(Context& ctx, _projectionName(std::move(projectionName)), _retainNonArrays(retainNonArrays) {} -void UnwindNode::generateUnwindMemo(std::ostringstream& os) const { +void UnwindNode::generateMemo(std::ostringstream& os) const { Node::generateMemoBase(os); - os << "Unwind" << "\n"; + os << "Unwind" + << "\n"; } -void UnwindNode::visit(AbstractVisitor& visitor) { +void UnwindNode::visit(AbstractVisitor& visitor) const { visitor.visit(*this); } @@ -197,12 +202,13 @@ NodePtr WindNode::create(Context& ctx, ProjectionName projectionName, NodePtr ch WindNode::WindNode(Context& ctx, ProjectionName projectionName, NodePtr child) : Node(ctx, std::move(child)), _projectionName(std::move(projectionName)) {} -void WindNode::generateWindMemo(std::ostringstream& os) const { +void WindNode::generateMemo(std::ostringstream& os) const { Node::generateMemoBase(os); - os << "Wind" << "\n"; + os << "Wind" + << "\n"; } -void WindNode::visit(AbstractVisitor& visitor) { +void WindNode::visit(AbstractVisitor& visitor) const { visitor.visit(*this); } diff --git a/src/mongo/db/query/optimizer/node.h b/src/mongo/db/query/optimizer/node.h index ba7da13bf98..fba45edab17 100644 --- a/src/mongo/db/query/optimizer/node.h +++ b/src/mongo/db/query/optimizer/node.h @@ -51,6 +51,9 @@ using NodePtr = std::unique_ptr<Node>; class AbstractVisitor; class Node { +public: + using ChildVector = std::vector<NodePtr>; + protected: explicit Node(Context& ctx); explicit Node(Context& ctx, NodePtr child); @@ -58,29 +61,31 @@ protected: void generateMemoBase(std::ostringstream& os) const; - virtual void visit(AbstractVisitor& visitor) = 0; + virtual void visit(AbstractVisitor& visitor) const = 0; - virtual void visitPreOrder(AbstractVisitor& visitor) final; + void visitPreOrder(AbstractVisitor& visitor) const; // clone public: Node() = delete; - virtual std::string generateMemo() final; + std::string generateMemo() const; + + //NodePtr clone(Context& ctx) const; private: const NodeIdType _nodeId; - std::vector<NodePtr> _children; + ChildVector _children; }; class ScanNode : public Node { public: static NodePtr create(Context& ctx, CollectionNameType collectionName); - void generateScanMemo(std::ostringstream& os) const; + void generateMemo(std::ostringstream& os) const; protected: - void visit(AbstractVisitor& visitor) override; + void visit(AbstractVisitor& visitor) const override; private: explicit ScanNode(Context& ctx, CollectionNameType collectionName); @@ -98,10 +103,10 @@ public: ProjectionMap projectionMap, std::vector<NodePtr> children); - void generateMultiJoinMemo(std::ostringstream& os) const; + void generateMemo(std::ostringstream& os) const; protected: - void visit(AbstractVisitor& visitor) override; + void visit(AbstractVisitor& visitor) const override; private: explicit MultiJoinNode(Context& ctx, @@ -117,10 +122,10 @@ class UnionNode : public Node { public: static NodePtr create(Context& ctx, std::vector<NodePtr> children); - void generateUnionMemo(std::ostringstream& os) const; + void generateMemo(std::ostringstream& os) const; protected: - void visit(AbstractVisitor& visitor) override; + void visit(AbstractVisitor& visitor) const override; private: explicit UnionNode(Context& ctx, std::vector<NodePtr> children); @@ -136,10 +141,10 @@ public: ProjectionMap projectionMap, NodePtr child); - void generateGroupByMemo(std::ostringstream& os) const; + void generateMemo(std::ostringstream& os) const; protected: - void visit(AbstractVisitor& visitor) override; + void visit(AbstractVisitor& visitor) const override; private: explicit GroupByNode(Context& ctx, @@ -158,11 +163,10 @@ public: bool retainNonArrays, NodePtr child); - void generateUnwindMemo(std::ostringstream& os) const; + void generateMemo(std::ostringstream& os) const; protected: - - void visit(AbstractVisitor& visitor) override; + void visit(AbstractVisitor& visitor) const override; private: UnwindNode(Context& ctx, ProjectionName projectionName, bool retainNonArrays, NodePtr child); @@ -175,10 +179,10 @@ class WindNode : public Node { public: static NodePtr create(Context& ctx, ProjectionName projectionName, NodePtr child); - void generateWindMemo(std::ostringstream& os) const; + void generateMemo(std::ostringstream& os) const; protected: - void visit(AbstractVisitor& visitor) override; + void visit(AbstractVisitor& visitor) const override; private: WindNode(Context& ctx, ProjectionName projectionName, NodePtr child); diff --git a/src/mongo/db/query/optimizer/optimizer_test.cpp b/src/mongo/db/query/optimizer/optimizer_test.cpp index 7eeaf801b16..f7eb881c531 100644 --- a/src/mongo/db/query/optimizer/optimizer_test.cpp +++ b/src/mongo/db/query/optimizer/optimizer_test.cpp @@ -36,8 +36,11 @@ namespace { TEST(Optimizer, Basic) { Context ctx; - NodePtr ptr = ScanNode::create(ctx, "test"); - ASSERT_EQ("NodeId: 0\nScan\n", ptr->generateMemo()); + NodePtr ptrScan = ScanNode::create(ctx, "test"); + Node::ChildVector v; + v.push_back(std::move(ptrScan)); + NodePtr ptrJoin = MultiJoinNode::create(ctx, {}, {}, std::move(v)); + ASSERT_EQ("NodeId: 1\nMultiJoin\nNodeId: 0\nScan\n", ptrJoin->generateMemo()); } } // namespace |