summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSvilen Mihaylov <svilen.mihaylov@mongodb.com>2020-09-09 15:48:21 -0400
committerSvilen Mihaylov <svilen.mihaylov@mongodb.com>2020-09-09 15:53:35 -0400
commit9edbc1726d8b6084b7440c0fd47a2f6a46b0bce6 (patch)
tree8f66017c3ce31bd5a3144e0675ded76e915ab622
parent18af06fa5a43e829035b625eba4dab504a905880 (diff)
downloadmongo-9edbc1726d8b6084b7440c0fd47a2f6a46b0bce6.tar.gz
fix visitor
-rw-r--r--src/mongo/db/query/optimizer/node.cpp74
-rw-r--r--src/mongo/db/query/optimizer/node.h38
-rw-r--r--src/mongo/db/query/optimizer/optimizer_test.cpp7
3 files changed, 65 insertions, 54 deletions
diff --git a/src/mongo/db/query/optimizer/node.cpp b/src/mongo/db/query/optimizer/node.cpp
index c52d55c1be9..d101f60b018 100644
--- a/src/mongo/db/query/optimizer/node.cpp
+++ b/src/mongo/db/query/optimizer/node.cpp
@@ -45,28 +45,35 @@ void Node::generateMemoBase(std::ostringstream& os) const {
os << "NodeId: " << _nodeId << "\n";
}
-std::string Node::generateMemo() {
- class MemoVisitor: public AbstractVisitor
- {
+void Node::visitPreOrder(AbstractVisitor& visitor) const {
+ visit(visitor);
+ for (const NodePtr& ptr : _children) {
+ ptr->visitPreOrder(visitor);
+ }
+}
+
+std::string Node::generateMemo() const {
+ class MemoVisitor : public AbstractVisitor {
protected:
void visit(const ScanNode& node) override {
- node.generateScanMemo(os);
+ node.generateMemo(os);
}
void visit(const MultiJoinNode& node) override {
- node.generateMultiJoinMemo(os);
+ node.generateMemo(os);
}
void visit(const UnionNode& node) override {
- node.generateUnionMemo(os);
+ node.generateMemo(os);
}
void visit(const GroupByNode& node) override {
- node.generateGroupByMemo(os);
+ node.generateMemo(os);
}
void visit(const UnwindNode& node) override {
- node.generateUnwindMemo(os);
+ node.generateMemo(os);
}
void visit(const WindNode& node) override {
- node.generateWindMemo(os);
+ node.generateMemo(os);
}
+
public:
std::ostringstream os;
};
@@ -76,13 +83,6 @@ std::string Node::generateMemo() {
return visitor.os.str();
}
-void Node::visitPreOrder(AbstractVisitor& visitor) {
- visit(visitor);
- for (const NodePtr& ptr: _children) {
- ptr->visitPreOrder(visitor);
- }
-}
-
NodePtr ScanNode::create(Context& ctx, CollectionNameType collectionName) {
return NodePtr(new ScanNode(ctx, std::move(collectionName)));
}
@@ -90,12 +90,13 @@ NodePtr ScanNode::create(Context& ctx, CollectionNameType collectionName) {
ScanNode::ScanNode(Context& ctx, CollectionNameType collectionName)
: Node(ctx), _collectionName(std::move(collectionName)) {}
-void ScanNode::generateScanMemo(std::ostringstream& os) const {
+void ScanNode::generateMemo(std::ostringstream& os) const {
Node::generateMemoBase(os);
- os << "Scan" << "\n";
+ os << "Scan"
+ << "\n";
}
-void ScanNode::visit(AbstractVisitor& visitor) {
+void ScanNode::visit(AbstractVisitor& visitor) const {
visitor.visit(*this);
}
@@ -115,12 +116,13 @@ MultiJoinNode::MultiJoinNode(Context& ctx,
_filterSet(std::move(filterSet)),
_projectionMap(std::move(projectionMap)) {}
-void MultiJoinNode::generateMultiJoinMemo(std::ostringstream& os) const {
+void MultiJoinNode::generateMemo(std::ostringstream& os) const {
Node::generateMemoBase(os);
- os << "MultiJoin" << "\n";
+ os << "MultiJoin"
+ << "\n";
}
-void MultiJoinNode::visit(AbstractVisitor& visitor) {
+void MultiJoinNode::visit(AbstractVisitor& visitor) const {
visitor.visit(*this);
}
@@ -131,12 +133,13 @@ NodePtr UnionNode::create(Context& ctx, std::vector<NodePtr> children) {
UnionNode::UnionNode(Context& ctx, std::vector<NodePtr> children)
: Node(ctx, std::move(children)) {}
-void UnionNode::generateUnionMemo(std::ostringstream& os) const {
+void UnionNode::generateMemo(std::ostringstream& os) const {
Node::generateMemoBase(os);
- os << "Union" << "\n";
+ os << "Union"
+ << "\n";
}
-void UnionNode::visit(AbstractVisitor& visitor) {
+void UnionNode::visit(AbstractVisitor& visitor) const {
visitor.visit(*this);
}
@@ -156,12 +159,13 @@ GroupByNode::GroupByNode(Context& ctx,
_groupByVector(std::move(groupByVector)),
_projectionMap(std::move(projectionMap)) {}
-void GroupByNode::generateGroupByMemo(std::ostringstream& os) const {
+void GroupByNode::generateMemo(std::ostringstream& os) const {
Node::generateMemoBase(os);
- os << "GroupBy" << "\n";
+ os << "GroupBy"
+ << "\n";
}
-void GroupByNode::visit(AbstractVisitor& visitor) {
+void GroupByNode::visit(AbstractVisitor& visitor) const {
visitor.visit(*this);
}
@@ -181,12 +185,13 @@ UnwindNode::UnwindNode(Context& ctx,
_projectionName(std::move(projectionName)),
_retainNonArrays(retainNonArrays) {}
-void UnwindNode::generateUnwindMemo(std::ostringstream& os) const {
+void UnwindNode::generateMemo(std::ostringstream& os) const {
Node::generateMemoBase(os);
- os << "Unwind" << "\n";
+ os << "Unwind"
+ << "\n";
}
-void UnwindNode::visit(AbstractVisitor& visitor) {
+void UnwindNode::visit(AbstractVisitor& visitor) const {
visitor.visit(*this);
}
@@ -197,12 +202,13 @@ NodePtr WindNode::create(Context& ctx, ProjectionName projectionName, NodePtr ch
WindNode::WindNode(Context& ctx, ProjectionName projectionName, NodePtr child)
: Node(ctx, std::move(child)), _projectionName(std::move(projectionName)) {}
-void WindNode::generateWindMemo(std::ostringstream& os) const {
+void WindNode::generateMemo(std::ostringstream& os) const {
Node::generateMemoBase(os);
- os << "Wind" << "\n";
+ os << "Wind"
+ << "\n";
}
-void WindNode::visit(AbstractVisitor& visitor) {
+void WindNode::visit(AbstractVisitor& visitor) const {
visitor.visit(*this);
}
diff --git a/src/mongo/db/query/optimizer/node.h b/src/mongo/db/query/optimizer/node.h
index 917144c81df..fba45edab17 100644
--- a/src/mongo/db/query/optimizer/node.h
+++ b/src/mongo/db/query/optimizer/node.h
@@ -51,7 +51,8 @@ using NodePtr = std::unique_ptr<Node>;
class AbstractVisitor;
class Node {
- friend class AbstractVisitor;
+public:
+ using ChildVector = std::vector<NodePtr>;
protected:
explicit Node(Context& ctx);
@@ -60,29 +61,31 @@ protected:
void generateMemoBase(std::ostringstream& os) const;
- virtual void visit(AbstractVisitor& visitor) = 0;
+ virtual void visit(AbstractVisitor& visitor) const = 0;
- virtual void visitPreOrder(AbstractVisitor& visitor) final;
+ void visitPreOrder(AbstractVisitor& visitor) const;
// clone
public:
Node() = delete;
- virtual std::string generateMemo() final;
+ std::string generateMemo() const;
+
+ //NodePtr clone(Context& ctx) const;
private:
const NodeIdType _nodeId;
- std::vector<NodePtr> _children;
+ ChildVector _children;
};
class ScanNode : public Node {
public:
static NodePtr create(Context& ctx, CollectionNameType collectionName);
- void generateScanMemo(std::ostringstream& os) const;
+ void generateMemo(std::ostringstream& os) const;
protected:
- void visit(AbstractVisitor& visitor) override;
+ void visit(AbstractVisitor& visitor) const override;
private:
explicit ScanNode(Context& ctx, CollectionNameType collectionName);
@@ -100,10 +103,10 @@ public:
ProjectionMap projectionMap,
std::vector<NodePtr> children);
- void generateMultiJoinMemo(std::ostringstream& os) const;
+ void generateMemo(std::ostringstream& os) const;
protected:
- void visit(AbstractVisitor& visitor) override;
+ void visit(AbstractVisitor& visitor) const override;
private:
explicit MultiJoinNode(Context& ctx,
@@ -119,10 +122,10 @@ class UnionNode : public Node {
public:
static NodePtr create(Context& ctx, std::vector<NodePtr> children);
- void generateUnionMemo(std::ostringstream& os) const;
+ void generateMemo(std::ostringstream& os) const;
protected:
- void visit(AbstractVisitor& visitor) override;
+ void visit(AbstractVisitor& visitor) const override;
private:
explicit UnionNode(Context& ctx, std::vector<NodePtr> children);
@@ -138,10 +141,10 @@ public:
ProjectionMap projectionMap,
NodePtr child);
- void generateGroupByMemo(std::ostringstream& os) const;
+ void generateMemo(std::ostringstream& os) const;
protected:
- void visit(AbstractVisitor& visitor) override;
+ void visit(AbstractVisitor& visitor) const override;
private:
explicit GroupByNode(Context& ctx,
@@ -160,11 +163,10 @@ public:
bool retainNonArrays,
NodePtr child);
- void generateUnwindMemo(std::ostringstream& os) const;
+ void generateMemo(std::ostringstream& os) const;
protected:
-
- void visit(AbstractVisitor& visitor) override;
+ void visit(AbstractVisitor& visitor) const override;
private:
UnwindNode(Context& ctx, ProjectionName projectionName, bool retainNonArrays, NodePtr child);
@@ -177,10 +179,10 @@ class WindNode : public Node {
public:
static NodePtr create(Context& ctx, ProjectionName projectionName, NodePtr child);
- void generateWindMemo(std::ostringstream& os) const;
+ void generateMemo(std::ostringstream& os) const;
protected:
- void visit(AbstractVisitor& visitor) override;
+ void visit(AbstractVisitor& visitor) const override;
private:
WindNode(Context& ctx, ProjectionName projectionName, NodePtr child);
diff --git a/src/mongo/db/query/optimizer/optimizer_test.cpp b/src/mongo/db/query/optimizer/optimizer_test.cpp
index 7eeaf801b16..f7eb881c531 100644
--- a/src/mongo/db/query/optimizer/optimizer_test.cpp
+++ b/src/mongo/db/query/optimizer/optimizer_test.cpp
@@ -36,8 +36,11 @@ namespace {
TEST(Optimizer, Basic) {
Context ctx;
- NodePtr ptr = ScanNode::create(ctx, "test");
- ASSERT_EQ("NodeId: 0\nScan\n", ptr->generateMemo());
+ NodePtr ptrScan = ScanNode::create(ctx, "test");
+ Node::ChildVector v;
+ v.push_back(std::move(ptrScan));
+ NodePtr ptrJoin = MultiJoinNode::create(ctx, {}, {}, std::move(v));
+ ASSERT_EQ("NodeId: 1\nMultiJoin\nNodeId: 0\nScan\n", ptrJoin->generateMemo());
}
} // namespace