diff options
-rw-r--r-- | src/mongo/db/query/optimizer/cascades/memo.h | 3 | ||||
-rw-r--r-- | src/mongo/db/query/optimizer/cascades/physical_rewriter.cpp | 3 | ||||
-rw-r--r-- | src/mongo/db/query/optimizer/explain.cpp | 292 | ||||
-rw-r--r-- | src/mongo/db/query/optimizer/explain.h | 2 | ||||
-rw-r--r-- | src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp | 23 |
5 files changed, 220 insertions, 103 deletions
diff --git a/src/mongo/db/query/optimizer/cascades/memo.h b/src/mongo/db/query/optimizer/cascades/memo.h index 40933e59d4d..a2a9124561f 100644 --- a/src/mongo/db/query/optimizer/cascades/memo.h +++ b/src/mongo/db/query/optimizer/cascades/memo.h @@ -86,6 +86,9 @@ struct PhysNodeInfo { // Rule that triggered the creation of this node. PhysicalRewriteType _rule; + + // Node-specific cardinality estimates, for explain. + NodeCEMap _nodeCEMap; }; struct PhysOptimizationResult { diff --git a/src/mongo/db/query/optimizer/cascades/physical_rewriter.cpp b/src/mongo/db/query/optimizer/cascades/physical_rewriter.cpp index 4d7208d50cb..6402101d507 100644 --- a/src/mongo/db/query/optimizer/cascades/physical_rewriter.cpp +++ b/src/mongo/db/query/optimizer/cascades/physical_rewriter.cpp @@ -175,7 +175,8 @@ void PhysicalRewriter::costAndRetainBestNode(std::unique_ptr<ABT> node, tassert(6678300, "Retaining node with uninitialized rewrite rule", rule != cascades::PhysicalRewriteType::Uninitialized); - PhysNodeInfo candidateNodeInfo{std::move(*node), cost, nodeCost, nodeCostAndCE._ce, rule}; + PhysNodeInfo candidateNodeInfo{ + std::move(*node), cost, nodeCost, nodeCostAndCE._ce, rule, std::move(nodeCEMap)}; const bool keepRejectedPlans = _hints._keepRejectedPlans; if (improvement) { if (keepRejectedPlans && bestResult._nodeInfo) { diff --git a/src/mongo/db/query/optimizer/explain.cpp b/src/mongo/db/query/optimizer/explain.cpp index 5796de6f03a..3a24f2a8a43 100644 --- a/src/mongo/db/query/optimizer/explain.cpp +++ b/src/mongo/db/query/optimizer/explain.cpp @@ -559,8 +559,12 @@ public: ExplainGeneratorTransporter(bool displayProperties = false, const cascades::Memo* memo = nullptr, - const NodeToGroupPropsMap& nodeMap = {}) - : _displayProperties(displayProperties), _memo(memo), _nodeMap(nodeMap) { + const NodeToGroupPropsMap& nodeMap = {}, + const boost::optional<const NodeCEMap&>& nodeCEMap = boost::none) + : _displayProperties(displayProperties), + _memo(memo), + _nodeMap(nodeMap), + _nodeCEMap(nodeCEMap) { uassert(6624005, "Memo must be provided in order to display properties.", !_displayProperties || (_memo != nullptr || version == ExplainVersion::V3)); @@ -572,7 +576,11 @@ public: * no-op. */ void maybePrintProps(ExplainPrinter& nodePrinter, const Node& node) { - if (!_displayProperties || version != ExplainVersion::V3 || _nodeMap.empty()) { + tassert(6701800, + "Cannot have both _displayProperties and _nodeCEMap set.", + !(_displayProperties && _nodeCEMap)); + if (_nodeCEMap || !_displayProperties || version != ExplainVersion::V3 || + _nodeMap.empty()) { return; } auto it = _nodeMap.find(&node); @@ -601,6 +609,25 @@ public: nodePrinter.printAppend(res); } + void nodeCEPropsPrint(ExplainPrinter& nodePrinter, const ABT& n, const Node& node) { + tassert(6701801, + "Cannot have both _displayProperties and _nodeCEMap set.", + !(_displayProperties && _nodeCEMap)); + // Only allow in V2 and V3 explain. No point in printing CE when we have a delegator + // node. + if (!_nodeCEMap || version == ExplainVersion::V1 || n.is<MemoLogicalDelegatorNode>() || + n.is<MemoPhysicalDelegatorNode>()) { + return; + } + auto it = _nodeCEMap->find(&node); + uassert(6701802, "Failed to find node ce", it != _nodeCEMap->end()); + const CEType ce = it->second; + + ExplainPrinter propsPrinter; + propsPrinter.fieldName("ce").print(ce); + nodePrinter.printAppend(propsPrinter); + } + static void printBooleanFlag(ExplainPrinter& printer, const std::string& name, const bool flag, @@ -634,13 +661,16 @@ public: /** * Nodes */ - ExplainPrinter transport(const References& references, std::vector<ExplainPrinter> inResults) { + ExplainPrinter transport(const ABT& /*n*/, + const References& references, + std::vector<ExplainPrinter> inResults) { ExplainPrinter printer; printer.separator("RefBlock: ").printAppend(inResults); return printer; } - ExplainPrinter transport(const ExpressionBinder& binders, + ExplainPrinter transport(const ABT& /*n*/, + const ExpressionBinder& binders, std::vector<ExplainPrinter> inResults) { std::map<std::string, ExplainPrinter> ordered; for (size_t idx = 0; idx < inResults.size(); ++idx) { @@ -698,37 +728,36 @@ public: } } - ExplainPrinter transport(const ScanNode& node, ExplainPrinter bindResult) { + ExplainPrinter transport(const ABT& n, const ScanNode& node, ExplainPrinter bindResult) { ExplainPrinter printer("Scan"); maybePrintProps(printer, node); - printer.separator(" [") .fieldName("scanDefName", ExplainVersion::V3) .print(node.getScanDefName()) - .separator("]") - .fieldName("bindings", ExplainVersion::V3) - .print(bindResult); + .separator("]"); + nodeCEPropsPrint(printer, n, node); + printer.fieldName("bindings", ExplainVersion::V3).print(bindResult); return printer; } - ExplainPrinter transport(const PhysicalScanNode& node, ExplainPrinter bindResult) { + ExplainPrinter transport(const ABT& n, + const PhysicalScanNode& node, + ExplainPrinter bindResult) { ExplainPrinter printer("PhysicalScan"); maybePrintProps(printer, node); - printer.separator(" [{"); printFieldProjectionMap(printer, node.getFieldProjectionMap()); printer.separator("}, ") .fieldName("scanDefName", ExplainVersion::V3) .print(node.getScanDefName()); - printBooleanFlag(printer, "parallel", node.useParallelScan()); - - printer.separator("]").fieldName("bindings", ExplainVersion::V3).print(bindResult); - + printer.separator("]"); + nodeCEPropsPrint(printer, n, node); + printer.fieldName("bindings", ExplainVersion::V3).print(bindResult); return printer; } - ExplainPrinter transport(const ValueScanNode& node, ExplainPrinter bindResult) { + ExplainPrinter transport(const ABT& n, const ValueScanNode& node, ExplainPrinter bindResult) { ExplainPrinter valuePrinter = generate(node.getValueArray()); // Specifically not printing optional logical properties here. They can be displayed with @@ -741,18 +770,20 @@ public: .print(node.getHasRID()) .fieldName("arraySize") .print(node.getArraySize()) - .separator("]") - .fieldName("values", ExplainVersion::V3) + .separator("]"); + nodeCEPropsPrint(printer, n, node); + printer.fieldName("values", ExplainVersion::V3) .print(valuePrinter) .fieldName("bindings", ExplainVersion::V3) .print(bindResult); return printer; } - ExplainPrinter transport(const CoScanNode& node) { + ExplainPrinter transport(const ABT& n, const CoScanNode& node) { ExplainPrinter printer("CoScan"); maybePrintProps(printer, node); printer.separator(" []"); + nodeCEPropsPrint(printer, n, node); return printer; } @@ -900,10 +931,9 @@ public: ExplainGeneratorTransporter& _instance; }; - ExplainPrinter transport(const IndexScanNode& node, ExplainPrinter bindResult) { + ExplainPrinter transport(const ABT& n, const IndexScanNode& node, ExplainPrinter bindResult) { ExplainPrinter printer("IndexScan"); maybePrintProps(printer, node); - printer.separator(" [{"); printFieldProjectionMap(printer, node.getFieldProjectionMap()); printer.separator("}, "); @@ -922,16 +952,18 @@ public: printBooleanFlag(printer, "reversed", spec.isReverseOrder()); - printer.separator("]").fieldName("bindings", ExplainVersion::V3).print(bindResult); + printer.separator("]"); + nodeCEPropsPrint(printer, n, node); + printer.fieldName("bindings", ExplainVersion::V3).print(bindResult); return printer; } - ExplainPrinter transport(const SeekNode& node, + ExplainPrinter transport(const ABT& n, + const SeekNode& node, ExplainPrinter bindResult, ExplainPrinter refsResult) { ExplainPrinter printer("Seek"); maybePrintProps(printer, node); - printer.separator(" [") .fieldName("ridProjection") .print(node.getRIDProjectionName()) @@ -940,8 +972,10 @@ public: printer.separator("}, ") .fieldName("scanDefName", ExplainVersion::V3) .print(node.getScanDefName()) - .separator("]") - .setChildCount(2) + .separator("]"); + nodeCEPropsPrint(printer, n, node); + + printer.setChildCount(2) .fieldName("bindings", ExplainVersion::V3) .print(bindResult) .fieldName("references", ExplainVersion::V3) @@ -950,14 +984,15 @@ public: return printer; } - ExplainPrinter transport(const MemoLogicalDelegatorNode& node) { + ExplainPrinter transport(const ABT& n, const MemoLogicalDelegatorNode& node) { ExplainPrinter printer("MemoLogicalDelegator"); maybePrintProps(printer, node); printer.separator(" [").fieldName("groupId").print(node.getGroupId()).separator("]"); + nodeCEPropsPrint(printer, n, node); return printer; } - ExplainPrinter transport(const MemoPhysicalDelegatorNode& node) { + ExplainPrinter transport(const ABT& /*n*/, const MemoPhysicalDelegatorNode& node) { const auto id = node.getNodeId(); if (_displayProperties) { @@ -1011,13 +1046,15 @@ public: return printer; } - ExplainPrinter transport(const FilterNode& node, + ExplainPrinter transport(const ABT& n, + const FilterNode& node, ExplainPrinter childResult, ExplainPrinter filterResult) { ExplainPrinter printer("Filter"); maybePrintProps(printer, node); - printer.separator(" []") - .setChildCount(2) + printer.separator(" []"); + nodeCEPropsPrint(printer, n, node); + printer.setChildCount(2) .fieldName("filter", ExplainVersion::V3) .print(filterResult) .fieldName("child", ExplainVersion::V3) @@ -1025,13 +1062,15 @@ public: return printer; } - ExplainPrinter transport(const EvaluationNode& node, + ExplainPrinter transport(const ABT& n, + const EvaluationNode& node, ExplainPrinter childResult, ExplainPrinter projectionResult) { ExplainPrinter printer("Evaluation"); maybePrintProps(printer, node); - printer.separator(" []") - .setChildCount(2) + printer.separator(" []"); + nodeCEPropsPrint(printer, n, node); + printer.setChildCount(2) .fieldName("projection", ExplainVersion::V3) .print(projectionResult) .fieldName("child", ExplainVersion::V3) @@ -1097,7 +1136,8 @@ public: parent.fieldName("residualReqs").print(printers); } - ExplainPrinter transport(const SargableNode& node, + ExplainPrinter transport(const ABT& n, + const SargableNode& node, ExplainPrinter childResult, ExplainPrinter bindResult, ExplainPrinter refsResult) { @@ -1108,8 +1148,9 @@ public: printer.separator(" [") .fieldName("target", ExplainVersion::V3) .print(IndexReqTargetEnum::toString[static_cast<int>(node.getTarget())]) - .separator("]") - .setChildCount(scanParams ? 6 : 5); + .separator("]"); + nodeCEPropsPrint(printer, n, node); + printer.setChildCount(scanParams ? 6 : 5); if constexpr (version < ExplainVersion::V3) { ExplainPrinter local; @@ -1229,20 +1270,21 @@ public: return printer; } - ExplainPrinter transport(const RIDIntersectNode& node, + ExplainPrinter transport(const ABT& n, + const RIDIntersectNode& node, ExplainPrinter leftChildResult, ExplainPrinter rightChildResult) { ExplainPrinter printer("RIDIntersect"); maybePrintProps(printer, node); - printer.separator(" [") .fieldName("scanProjectionName", ExplainVersion::V3) .print(node.getScanProjectionName()); printBooleanFlag(printer, "hasLeftIntervals", node.hasLeftIntervals()); printBooleanFlag(printer, "hasRightIntervals", node.hasRightIntervals()); - printer.separator("]") - .setChildCount(2) + printer.separator("]"); + nodeCEPropsPrint(printer, n, node); + printer.setChildCount(2) .maybeReverse() .fieldName("leftChild", ExplainVersion::V3) .print(leftChildResult) @@ -1251,13 +1293,13 @@ public: return printer; } - ExplainPrinter transport(const BinaryJoinNode& node, + ExplainPrinter transport(const ABT& n, + const BinaryJoinNode& node, ExplainPrinter leftChildResult, ExplainPrinter rightChildResult, ExplainPrinter filterResult) { ExplainPrinter printer("BinaryJoin"); maybePrintProps(printer, node); - printer.separator(" [") .fieldName("joinType") .print(JoinTypeEnum::toString[static_cast<int>(node.getJoinType())]); @@ -1293,8 +1335,9 @@ public: MONGO_UNREACHABLE; } - printer.separator("]") - .setChildCount(3) + printer.separator("]"); + nodeCEPropsPrint(printer, n, node); + printer.setChildCount(3) .fieldName("expression", ExplainVersion::V3) .print(filterResult) .maybeReverse() @@ -1331,17 +1374,18 @@ public: } } - ExplainPrinter transport(const HashJoinNode& node, + ExplainPrinter transport(const ABT& n, + const HashJoinNode& node, ExplainPrinter leftChildResult, ExplainPrinter rightChildResult, ExplainPrinter /*refsResult*/) { ExplainPrinter printer("HashJoin"); maybePrintProps(printer, node); - printer.separator(" [") .fieldName("joinType") .print(JoinTypeEnum::toString[static_cast<int>(node.getJoinType())]) .separator("]"); + nodeCEPropsPrint(printer, n, node); ExplainPrinter joinConditionPrinter; printEqualityJoinCondition(joinConditionPrinter, node.getLeftKeys(), node.getRightKeys()); @@ -1357,13 +1401,15 @@ public: return printer; } - ExplainPrinter transport(const MergeJoinNode& node, + ExplainPrinter transport(const ABT& n, + const MergeJoinNode& node, ExplainPrinter leftChildResult, ExplainPrinter rightChildResult, ExplainPrinter /*refsResult*/) { ExplainPrinter printer("MergeJoin"); maybePrintProps(printer, node); printer.separator(" []"); + nodeCEPropsPrint(printer, n, node); ExplainPrinter joinConditionPrinter; printEqualityJoinCondition(joinConditionPrinter, node.getLeftKeys(), node.getRightKeys()); @@ -1401,14 +1447,16 @@ public: return printer; } - ExplainPrinter transport(const UnionNode& node, + ExplainPrinter transport(const ABT& n, + const UnionNode& node, std::vector<ExplainPrinter> childResults, ExplainPrinter bindResult, ExplainPrinter /*refsResult*/) { ExplainPrinter printer("Union"); maybePrintProps(printer, node); - printer.separator(" []") - .setChildCount(childResults.size() + 1) + printer.separator(" []"); + nodeCEPropsPrint(printer, n, node); + printer.setChildCount(childResults.size() + 1) .fieldName("bindings", ExplainVersion::V3) .print(bindResult) .maybeReverse() @@ -1417,7 +1465,8 @@ public: return printer; } - ExplainPrinter transport(const GroupByNode& node, + ExplainPrinter transport(const ABT& n, + const GroupByNode& node, ExplainPrinter childResult, ExplainPrinter bindAggResult, ExplainPrinter refsAggResult, @@ -1437,6 +1486,7 @@ public: .print(GroupNodeTypeEnum::toString[static_cast<int>(node.getType())]); } printer.separator("]"); + nodeCEPropsPrint(printer, n, node); std::vector<ExplainPrinter> aggPrinters; for (const auto& [projectionName, index] : ordered) { @@ -1464,16 +1514,17 @@ public: return printer; } - ExplainPrinter transport(const UnwindNode& node, + ExplainPrinter transport(const ABT& n, + const UnwindNode& node, ExplainPrinter childResult, ExplainPrinter bindResult, ExplainPrinter refsResult) { ExplainPrinter printer("Unwind"); maybePrintProps(printer, node); - printer.separator(" ["); printBooleanFlag(printer, "retainNonArrays", node.getRetainNonArrays(), false /*addComma*/); printer.separator("]"); + nodeCEPropsPrint(printer, n, node); printer.setChildCount(2) .fieldName("bind", ExplainVersion::V3) @@ -1502,26 +1553,32 @@ public: }); } - ExplainPrinter transport(const UniqueNode& node, + ExplainPrinter transport(const ABT& n, + const UniqueNode& node, ExplainPrinter childResult, ExplainPrinter /*refsResult*/) { ExplainPrinter printer("Unique"); maybePrintProps(printer, node); + printer.separator(" []"); + nodeCEPropsPrint(printer, n, node); - printer.separator(" []").setChildCount(2); + printer.setChildCount(2); printPropertyProjections(printer, node.getProjections(), false /*directToParent*/); printer.fieldName("child", ExplainVersion::V3).print(childResult); return printer; } - ExplainPrinter transport(const CollationNode& node, + ExplainPrinter transport(const ABT& n, + const CollationNode& node, ExplainPrinter childResult, ExplainPrinter refsResult) { ExplainPrinter printer("Collation"); maybePrintProps(printer, node); + printer.separator(" []"); + nodeCEPropsPrint(printer, n, node); - printer.separator(" []").setChildCount(3); + printer.setChildCount(3); printCollationProperty(printer, node.getProperty(), false /*directToParent*/); printer.fieldName("references", ExplainVersion::V3) .print(refsResult) @@ -1561,11 +1618,13 @@ public: }); } - ExplainPrinter transport(const LimitSkipNode& node, ExplainPrinter childResult) { + ExplainPrinter transport(const ABT& n, const LimitSkipNode& node, ExplainPrinter childResult) { ExplainPrinter printer("LimitSkip"); maybePrintProps(printer, node); + printer.separator(" []"); + nodeCEPropsPrint(printer, n, node); - printer.separator(" []").setChildCount(2); + printer.setChildCount(2); printLimitSkipProperty(printer, node.getProperty(), false /*directToParent*/); printer.fieldName("child", ExplainVersion::V3).print(childResult); @@ -1624,13 +1683,16 @@ public: printPropertyProjections(parent, property.getProjections().getVector(), directToParent); } - ExplainPrinter transport(const ExchangeNode& node, + ExplainPrinter transport(const ABT& n, + const ExchangeNode& node, ExplainPrinter childResult, ExplainPrinter refsResult) { ExplainPrinter printer("Exchange"); maybePrintProps(printer, node); + printer.separator(" []"); + nodeCEPropsPrint(printer, n, node); - printer.separator(" []").setChildCount(3); + printer.setChildCount(3); printDistributionProperty(printer, node.getProperty(), false /*directToParent*/); printer.fieldName("references", ExplainVersion::V3) .print(refsResult) @@ -1868,13 +1930,16 @@ public: return printProps<properties::PhysProperty, PhysPropPrintVisitor>(description, props); } - ExplainPrinter transport(const RootNode& node, + ExplainPrinter transport(const ABT& n, + const RootNode& node, ExplainPrinter childResult, ExplainPrinter refsResult) { ExplainPrinter printer("Root"); maybePrintProps(printer, node); + printer.separator(" []"); + nodeCEPropsPrint(printer, n, node); - printer.separator(" []").setChildCount(3); + printer.setChildCount(3); printProjectionRequirementProperty(printer, node.getProperty(), false /*directToParent*/); printer.fieldName("references", ExplainVersion::V3) .print(refsResult) @@ -1887,13 +1952,13 @@ public: /** * Expressions */ - ExplainPrinter transport(const Blackhole& expr) { + ExplainPrinter transport(const ABT& /*n*/, const Blackhole& expr) { ExplainPrinter printer("Blackhole"); printer.separator(" []"); return printer; } - ExplainPrinter transport(const Constant& expr) { + ExplainPrinter transport(const ABT& /*n*/, const Constant& expr) { ExplainPrinter printer("Const"); printer.separator(" [").fieldName("tag", ExplainVersion::V3); @@ -1908,7 +1973,7 @@ public: return printer; } - ExplainPrinter transport(const Variable& expr) { + ExplainPrinter transport(const ABT& /*n*/, const Variable& expr) { ExplainPrinter printer("Variable"); printer.separator(" [") .fieldName("name", ExplainVersion::V3) @@ -1917,7 +1982,7 @@ public: return printer; } - ExplainPrinter transport(const UnaryOp& expr, ExplainPrinter inResult) { + ExplainPrinter transport(const ABT& /*n*/, const UnaryOp& expr, ExplainPrinter inResult) { ExplainPrinter printer("UnaryOp"); printer.separator(" [") .fieldName("op", ExplainVersion::V3) @@ -1929,7 +1994,8 @@ public: return printer; } - ExplainPrinter transport(const BinaryOp& expr, + ExplainPrinter transport(const ABT& /*n*/, + const BinaryOp& expr, ExplainPrinter leftResult, ExplainPrinter rightResult) { ExplainPrinter printer("BinaryOp"); @@ -1947,7 +2013,8 @@ public: } - ExplainPrinter transport(const If& expr, + ExplainPrinter transport(const ABT& /*n*/, + const If& expr, ExplainPrinter condResult, ExplainPrinter thenResult, ExplainPrinter elseResult) { @@ -1964,7 +2031,8 @@ public: return printer; } - ExplainPrinter transport(const Let& expr, + ExplainPrinter transport(const ABT& /*n*/, + const Let& expr, ExplainPrinter bindResult, ExplainPrinter exprResult) { ExplainPrinter printer("Let"); @@ -1981,7 +2049,9 @@ public: return printer; } - ExplainPrinter transport(const LambdaAbstraction& expr, ExplainPrinter inResult) { + ExplainPrinter transport(const ABT& /*n*/, + const LambdaAbstraction& expr, + ExplainPrinter inResult) { ExplainPrinter printer("LambdaAbstraction"); printer.separator(" [") .fieldName("variable", ExplainVersion::V3) @@ -1993,7 +2063,8 @@ public: return printer; } - ExplainPrinter transport(const LambdaApplication& expr, + ExplainPrinter transport(const ABT& /*n*/, + const LambdaApplication& expr, ExplainPrinter lambdaResult, ExplainPrinter argumentResult) { ExplainPrinter printer("LambdaApplication"); @@ -2007,7 +2078,9 @@ public: return printer; } - ExplainPrinter transport(const FunctionCall& expr, std::vector<ExplainPrinter> argResults) { + ExplainPrinter transport(const ABT& /*n*/, + const FunctionCall& expr, + std::vector<ExplainPrinter> argResults) { ExplainPrinter printer("FunctionCall"); printer.separator(" [") .fieldName("name", ExplainVersion::V3) @@ -2022,7 +2095,8 @@ public: return printer; } - ExplainPrinter transport(const EvalPath& expr, + ExplainPrinter transport(const ABT& /*n*/, + const EvalPath& expr, ExplainPrinter pathResult, ExplainPrinter inputResult) { ExplainPrinter printer("EvalPath"); @@ -2036,7 +2110,8 @@ public: return printer; } - ExplainPrinter transport(const EvalFilter& expr, + ExplainPrinter transport(const ABT& /*n*/, + const EvalFilter& expr, ExplainPrinter pathResult, ExplainPrinter inputResult) { ExplainPrinter printer("EvalFilter"); @@ -2053,7 +2128,7 @@ public: /** * Paths */ - ExplainPrinter transport(const PathConstant& path, ExplainPrinter inResult) { + ExplainPrinter transport(const ABT& /*n*/, const PathConstant& path, ExplainPrinter inResult) { ExplainPrinter printer("PathConstant"); printer.separator(" []") .setChildCount(1) @@ -2062,7 +2137,7 @@ public: return printer; } - ExplainPrinter transport(const PathLambda& path, ExplainPrinter inResult) { + ExplainPrinter transport(const ABT& /*n*/, const PathLambda& path, ExplainPrinter inResult) { ExplainPrinter printer("PathLambda"); printer.separator(" []") .setChildCount(1) @@ -2071,13 +2146,13 @@ public: return printer; } - ExplainPrinter transport(const PathIdentity& path) { + ExplainPrinter transport(const ABT& /*n*/, const PathIdentity& path) { ExplainPrinter printer("PathIdentity"); printer.separator(" []"); return printer; } - ExplainPrinter transport(const PathDefault& path, ExplainPrinter inResult) { + ExplainPrinter transport(const ABT& /*n*/, const PathDefault& path, ExplainPrinter inResult) { ExplainPrinter printer("PathDefault"); printer.separator(" []") .setChildCount(1) @@ -2086,7 +2161,9 @@ public: return printer; } - ExplainPrinter transport(const PathCompare& path, ExplainPrinter valueResult) { + ExplainPrinter transport(const ABT& /*n*/, + const PathCompare& path, + ExplainPrinter valueResult) { ExplainPrinter printer("PathCompare"); printer.separator(" [") .fieldName("op", ExplainVersion::V3) @@ -2122,7 +2199,7 @@ public: } } - ExplainPrinter transport(const PathDrop& path) { + ExplainPrinter transport(const ABT& /*n*/, const PathDrop& path) { ExplainPrinter printer("PathDrop"); printer.separator(" ["); printPathProjections(printer, path.getNames()); @@ -2130,7 +2207,7 @@ public: return printer; } - ExplainPrinter transport(const PathKeep& path) { + ExplainPrinter transport(const ABT& /*n*/, const PathKeep& path) { ExplainPrinter printer("PathKeep"); printer.separator(" ["); printPathProjections(printer, path.getNames()); @@ -2138,19 +2215,19 @@ public: return printer; } - ExplainPrinter transport(const PathObj& path) { + ExplainPrinter transport(const ABT& /*n*/, const PathObj& path) { ExplainPrinter printer("PathObj"); printer.separator(" []"); return printer; } - ExplainPrinter transport(const PathArr& path) { + ExplainPrinter transport(const ABT& /*n*/, const PathArr& path) { ExplainPrinter printer("PathArr"); printer.separator(" []"); return printer; } - ExplainPrinter transport(const PathTraverse& path, ExplainPrinter inResult) { + ExplainPrinter transport(const ABT& /*n*/, const PathTraverse& path, ExplainPrinter inResult) { ExplainPrinter printer("PathTraverse"); printer.separator(" ["); @@ -2173,7 +2250,7 @@ public: return printer; } - ExplainPrinter transport(const PathField& path, ExplainPrinter inResult) { + ExplainPrinter transport(const ABT& /*n*/, const PathField& path, ExplainPrinter inResult) { ExplainPrinter printer("PathField"); printer.separator(" [") .fieldName("path", ExplainVersion::V3) @@ -2185,7 +2262,7 @@ public: return printer; } - ExplainPrinter transport(const PathGet& path, ExplainPrinter inResult) { + ExplainPrinter transport(const ABT& /*n*/, const PathGet& path, ExplainPrinter inResult) { ExplainPrinter printer("PathGet"); printer.separator(" [") .fieldName("path", ExplainVersion::V3) @@ -2197,7 +2274,8 @@ public: return printer; } - ExplainPrinter transport(const PathComposeM& path, + ExplainPrinter transport(const ABT& /*n*/, + const PathComposeM& path, ExplainPrinter leftResult, ExplainPrinter rightResult) { ExplainPrinter printer("PathComposeM"); @@ -2211,7 +2289,8 @@ public: return printer; } - ExplainPrinter transport(const PathComposeA& path, + ExplainPrinter transport(const ABT& /*n*/, + const PathComposeA& path, ExplainPrinter leftResult, ExplainPrinter rightResult) { ExplainPrinter printer("PathComposeA"); @@ -2225,14 +2304,14 @@ public: return printer; } - ExplainPrinter transport(const Source& expr) { + ExplainPrinter transport(const ABT& /*n*/, const Source& expr) { ExplainPrinter printer("Source"); printer.separator(" []"); return printer; } ExplainPrinter generate(const ABT& node) { - return algebra::transport<false>(node, *this); + return algebra::transport<true>(node, *this); } void printPhysNodeInfo(ExplainPrinter& printer, const cascades::PhysNodeInfo& nodeInfo) { @@ -2249,7 +2328,9 @@ public: .fieldName("adjustedCE") .print(nodeInfo._adjustedCE); - ExplainPrinter nodePrinter = generate(nodeInfo._node); + ExplainGeneratorTransporter<version> subGen( + _displayProperties, _memo, _nodeMap, nodeInfo._nodeCEMap); + ExplainPrinter nodePrinter = subGen.generate(nodeInfo._node); printer.separator(", ").fieldName("node").print(nodePrinter); } @@ -2354,6 +2435,7 @@ private: // We don't own this. const cascades::Memo* _memo; const NodeToGroupPropsMap& _nodeMap; + boost::optional<const NodeCEMap&> _nodeCEMap; }; std::string ExplainGenerator::explain(const ABT& node, @@ -2396,18 +2478,20 @@ std::pair<sbe::value::TypeTags, sbe::value::Value> ExplainGenerator::explainBSON return gen.generate(node).moveValue(); } +BSONObj convertSbeValToBSONObj(const std::pair<sbe::value::TypeTags, sbe::value::Value> val) { + uassert(6624070, "Expected an object", val.first == sbe::value::TypeTags::Object); + sbe::value::ValueGuard vg(val.first, val.second); + + BSONObjBuilder builder; + sbe::bson::convertToBsonObj(builder, sbe::value::getObjectView(val.second)); + return builder.done().getOwned(); +} + BSONObj ExplainGenerator::explainBSONObj(const ABT& node, const bool displayProperties, const cascades::Memo* memo, const NodeToGroupPropsMap& nodeMap) { - auto [tag, val] = - optimizer::ExplainGenerator::explainBSON(node, displayProperties, memo, nodeMap); - uassert(6624070, "Expected an object", tag == sbe::value::TypeTags::Object); - sbe::value::ValueGuard vg(tag, val); - - BSONObjBuilder builder; - sbe::bson::convertToBsonObj(builder, sbe::value::getObjectView(val)); - return builder.done().getOwned(); + return convertSbeValToBSONObj(explainBSON(node, displayProperties, memo, nodeMap)); } template <class PrinterType> @@ -2488,6 +2572,10 @@ std::pair<sbe::value::TypeTags, sbe::value::Value> ExplainGenerator::explainMemo return gen.printMemo().moveValue(); } +BSONObj ExplainGenerator::explainMemoBSONObj(const cascades::Memo& memo) { + return convertSbeValToBSONObj(explainMemoBSON(memo)); +} + std::string ExplainGenerator::explainPartialSchemaReqMap(const PartialSchemaRequirements& reqMap) { ExplainGeneratorTransporter<ExplainVersion::V2> gen; ExplainGeneratorTransporter<ExplainVersion::V2>::ExplainPrinter result; diff --git a/src/mongo/db/query/optimizer/explain.h b/src/mongo/db/query/optimizer/explain.h index dfc8cdda77d..522d09af5e9 100644 --- a/src/mongo/db/query/optimizer/explain.h +++ b/src/mongo/db/query/optimizer/explain.h @@ -107,6 +107,8 @@ public: static std::pair<sbe::value::TypeTags, sbe::value::Value> explainMemoBSON( const cascades::Memo& memo); + static BSONObj explainMemoBSONObj(const cascades::Memo& memo); + static std::string explainPartialSchemaReqMap(const PartialSchemaRequirements& reqMap); static std::string explainInterval(const IntervalRequirement& interval); diff --git a/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp b/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp index a98155964e4..4b8e605bcab 100644 --- a/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp +++ b/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp @@ -352,6 +352,17 @@ TEST(PhysRewriter, GroupBy1) { phaseManager.optimize(optimized); ASSERT_EQ(5, phaseManager.getMemo().getStats()._physPlanExplorationCount); + // Assert we have specific CE details at certain nodes. + std::vector<std::string> cePaths = {"Memo.0.physicalNodes.0.nodeInfo.node.ce", + "Memo.1.physicalNodes.0.nodeInfo.node.ce", + "Memo.3.physicalNodes.0.nodeInfo.node.ce", + "Memo.4.physicalNodes.0.nodeInfo.node.ce"}; + BSONObj bsonMemo = ExplainGenerator::explainMemoBSONObj(phaseManager.getMemo()); + for (const auto& cePath : cePaths) { + BSONElement ce = dotted_path_support::extractElementAtPath(bsonMemo, cePath); + ASSERT(!ce.eoo()); + } + // Projection "pb1" is unused and we do not generate an aggregation expression for it. ASSERT_EXPLAIN_V2( "Root []\n" @@ -1296,6 +1307,18 @@ TEST(PhysRewriter, FilterIndexing4) { phaseManager.optimize(optimized); ASSERT_BETWEEN(60, 75, phaseManager.getMemo().getStats()._physPlanExplorationCount); + // Assert the correct CEs for each node in group 1. Group 1 contains residual predicates. + std::vector<std::pair<std::string, double>> pathAndCEs = { + {"Memo.1.physicalNodes.1.nodeInfo.node.ce", 143.6810174757394}, + {"Memo.1.physicalNodes.1.nodeInfo.node.child.ce", 189.57056733575502}, + {"Memo.1.physicalNodes.1.nodeInfo.node.child.child.ce", 330.00000000000006}, + {"Memo.1.physicalNodes.1.nodeInfo.node.child.child.child.ce", 330.00000000000006}}; + const BSONObj explain = ExplainGenerator::explainMemoBSONObj(phaseManager.getMemo()); + for (const auto& pathAndCE : pathAndCEs) { + BSONElement el = dotted_path_support::extractElementAtPath(explain, pathAndCE.first); + ASSERT_EQ(el.Double(), pathAndCE.second); + } + ASSERT_EXPLAIN_V2( "Root []\n" "| | projections: \n" |