/**
 *    Copyright (C) 2022-present MongoDB, Inc.
 *
 *    This program is free software: you can redistribute it and/or modify
 *    it under the terms of the Server Side Public License, version 1,
 *    as published by MongoDB, Inc.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *    Server Side Public License for more details.
 *
 *    You should have received a copy of the Server Side Public License
 *    along with this program. If not, see
 *    <http://www.mongodb.com/licensing/server-side-public-license>.
 *
 *    As a special exception, the copyright holders give permission to link the
 *    code of portions of this program with the OpenSSL library under certain
 *    conditions as described in each individual source file and distribute
 *    linked combinations including the program with the OpenSSL library. You
 *    must comply with the Server Side Public License in all respects for
 *    all of the code used other than as permitted herein. If you modify file(s)
 *    with this exception, you may extend this exception to your version of the
 *    file(s), but you are not obligated to do so. If you do not wish to do so,
 *    delete this exception statement from your version. If you delete this
 *    exception statement from all source files in the program, then also delete
 *    it in the license file.
 */

#include "mongo/db/query/optimizer/cascades/logical_rewriter.h"
#include "mongo/db/query/optimizer/cascades/rewriter_rules.h"
#include "mongo/db/query/optimizer/reference_tracker.h"
#include "mongo/db/query/optimizer/utils/reftracker_utils.h"

namespace mongo::optimizer::cascades {

LogicalRewriter::RewriteSet LogicalRewriter::_explorationSet = {
    {LogicalRewriteType::GroupByExplore, 1},
    {LogicalRewriteType::SargableSplit, 2},
    {LogicalRewriteType::FilterRIDIntersectReorder, 2},
    {LogicalRewriteType::EvaluationRIDIntersectReorder, 2}};

LogicalRewriter::RewriteSet LogicalRewriter::_substitutionSet = {
    {LogicalRewriteType::FilterEvaluationReorder, 1},
    {LogicalRewriteType::FilterCollationReorder, 1},
    {LogicalRewriteType::EvaluationCollationReorder, 1},
    {LogicalRewriteType::EvaluationLimitSkipReorder, 1},
    {LogicalRewriteType::FilterGroupByReorder, 1},
    {LogicalRewriteType::GroupCollationReorder, 1},
    {LogicalRewriteType::FilterUnwindReorder, 1},
    {LogicalRewriteType::EvaluationUnwindReorder, 1},
    {LogicalRewriteType::UnwindCollationReorder, 1},
    {LogicalRewriteType::FilterExchangeReorder, 1},
    {LogicalRewriteType::ExchangeEvaluationReorder, 1},
    {LogicalRewriteType::FilterUnionReorder, 1},
    {LogicalRewriteType::CollationMerge, 1},
    {LogicalRewriteType::LimitSkipMerge, 1},
    {LogicalRewriteType::SargableFilterReorder, 1},
    {LogicalRewriteType::SargableEvaluationReorder, 1},
    {LogicalRewriteType::FilterValueScanPropagate, 1},
    {LogicalRewriteType::EvaluationValueScanPropagate, 1},
    {LogicalRewriteType::SargableValueScanPropagate, 1},
    {LogicalRewriteType::CollationValueScanPropagate, 1},
    {LogicalRewriteType::LimitSkipValueScanPropagate, 1},
    {LogicalRewriteType::ExchangeValueScanPropagate, 1},
    {LogicalRewriteType::LimitSkipSubstitute, 1},
    {LogicalRewriteType::FilterSubstitute, 2},
    {LogicalRewriteType::EvaluationSubstitute, 2},
    {LogicalRewriteType::SargableMerge, 2}};

LogicalRewriter::LogicalRewriter(const Metadata& metadata,
                                 Memo& memo,
                                 PrefixId& prefixId,
                                 const RewriteSet rewriteSet,
                                 const DebugInfo& debugInfo,
                                 const QueryHints& hints,
                                 const PathToIntervalFn& pathToInterval,
                                 const ConstFoldFn& constFold,
                                 const LogicalPropsInterface& logicalPropsDerivation,
                                 const CEInterface& ceDerivation)
    : _activeRewriteSet(std::move(rewriteSet)),
      _groupsPending(),
      _metadata(metadata),
      _memo(memo),
      _prefixId(prefixId),
      _debugInfo(debugInfo),
      _hints(hints),
      _pathToInterval(pathToInterval),
      _constFold(constFold),
      _logicalPropsDerivation(logicalPropsDerivation),
      _ceDerivation(ceDerivation) {
    initializeRewrites();

    if (_activeRewriteSet.count(LogicalRewriteType::SargableSplit) > 0) {
        // If we are performing SargableSplit exploration rewrite, populate helper map.
        for (const auto& [scanDefName, scanDef] : _metadata._scanDefs) {
            for (const auto& [indexDefName, indexDef] : scanDef.getIndexDefs()) {
                for (const IndexCollationEntry& entry : indexDef.getCollationSpec()) {
                    if (auto pathPtr = entry._path.cast<PathGet>(); pathPtr != nullptr) {
                        _indexFieldPrefixMap[scanDefName].insert(pathPtr->name());
                    }
                }
            }
        }
    }
}

GroupIdType LogicalRewriter::addRootNode(const ABT& node) {
    return addNode(node, -1, LogicalRewriteType::Root, false /*addExistingNodeWithNewChild*/)
        .first;
}

std::pair<GroupIdType, NodeIdSet> LogicalRewriter::addNode(
    const ABT& node,
    const GroupIdType targetGroupId,
    const LogicalRewriteType rule,
    const bool addExistingNodeWithNewChild) {
    NodeIdSet insertNodeIds;

    Memo::NodeTargetGroupMap targetGroupMap;
    if (targetGroupId >= 0) {
        targetGroupMap = {{node.ref(), targetGroupId}};
    }

    const GroupIdType resultGroupId = _memo.integrate(
        Memo::Context{&_metadata, &_debugInfo, &_logicalPropsDerivation, &_ceDerivation},
        node,
        std::move(targetGroupMap),
        insertNodeIds,
        rule,
        addExistingNodeWithNewChild);

    uassert(6624046,
            "Result group is not the same as target group",
            targetGroupId < 0 || targetGroupId == resultGroupId);

    for (const MemoLogicalNodeId& nodeMemoId : insertNodeIds) {
        if (addExistingNodeWithNewChild && nodeMemoId._groupId == targetGroupId) {
            continue;
        }

        for (const auto [type, priority] : _activeRewriteSet) {
            auto& groupQueue = _memo.getLogicalRewriteQueue(nodeMemoId._groupId);
            groupQueue.push(std::make_unique<LogicalRewriteEntry>(priority, type, nodeMemoId));

            _groupsPending.insert(nodeMemoId._groupId);
        }
    }

    return {resultGroupId, std::move(insertNodeIds)};
}

void LogicalRewriter::clearGroup(const GroupIdType groupId) {
    _memo.clearLogicalNodes(groupId);
}

class RewriteContext {
public:
    RewriteContext(LogicalRewriter& rewriter,
                   const LogicalRewriteType rule,
                   const MemoLogicalNodeId aboveNodeId,
                   const MemoLogicalNodeId belowNodeId)
        : RewriteContext(rewriter, rule, aboveNodeId, true /*hasBelowNodeId*/, belowNodeId){};

    RewriteContext(LogicalRewriter& rewriter,
                   const LogicalRewriteType rule,
                   const MemoLogicalNodeId aboveNodeId)
        : RewriteContext(rewriter, rule, aboveNodeId, false /*hasBelowNodeId*/, {}){};

    std::pair<GroupIdType, NodeIdSet> addNode(const ABT& node,
                                              const bool substitute,
                                              const bool addExistingNodeWithNewChild = false) {
        if (substitute) {
            uassert(6624110, "Cannot substitute twice", !_hasSubstituted);
            _hasSubstituted = true;

            _rewriter.clearGroup(_aboveNodeId._groupId);
            if (_hasBelowNodeId) {
                _rewriter.clearGroup(_belowNodeId._groupId);
            }
        }
        return _rewriter.addNode(node, _aboveNodeId._groupId, _rule, addExistingNodeWithNewChild);
    }

    Memo& getMemo() const {
        return _rewriter._memo;
    }

    const Metadata& getMetadata() const {
        return _rewriter._metadata;
    }

    PrefixId& getPrefixId() const {
        return _rewriter._prefixId;
    }

    const QueryHints& getHints() const {
        return _rewriter._hints;
    }

    auto& getIndexFieldPrefixMap() const {
        return _rewriter._indexFieldPrefixMap;
    }

    const properties::LogicalProps& getAboveLogicalProps() const {
        return getMemo().getLogicalProps(_aboveNodeId._groupId);
    }

    bool hasSubstituted() const {
        return _hasSubstituted;
    }

    MemoLogicalNodeId getAboveNodeId() const {
        return _aboveNodeId;
    }

    auto& getSargableSplitCountMap() const {
        return _rewriter._sargableSplitCountMap;
    }

    const auto& getPathToInterval() const {
        return _rewriter._pathToInterval;
    }

    const auto& getConstFold() const {
        return _rewriter._constFold;
    }

private:
    RewriteContext(LogicalRewriter& rewriter,
                   const LogicalRewriteType rule,
                   const MemoLogicalNodeId aboveNodeId,
                   const bool hasBelowNodeId,
                   const MemoLogicalNodeId belowNodeId)
        : _aboveNodeId(aboveNodeId),
          _hasBelowNodeId(hasBelowNodeId),
          _belowNodeId(belowNodeId),
          _rewriter(rewriter),
          _hasSubstituted(false),
          _rule(rule){};

    const MemoLogicalNodeId _aboveNodeId;
    const bool _hasBelowNodeId;
    const MemoLogicalNodeId _belowNodeId;

    // We don't own this.
    LogicalRewriter& _rewriter;

    bool _hasSubstituted;
    const LogicalRewriteType _rule;
};

struct ReorderDependencies {
    bool _hasNodeRef = false;
    bool _hasChildRef = false;
    bool _hasNodeAndChildRef = false;
};

template <class NodeType>
struct DefaultChildAccessor {
    const ABT& operator()(const ABT& node) const {
        return node.cast<NodeType>()->getChild();
    }

    ABT& operator()(ABT& node) const {
        return node.cast<NodeType>()->getChild();
    }
};

template <class NodeType>
struct LeftChildAccessor {
    const ABT& operator()(const ABT& node) const {
        return node.cast<NodeType>()->getLeftChild();
    }

    ABT& operator()(ABT& node) const {
        return node.cast<NodeType>()->getLeftChild();
    }
};

template <class NodeType>
struct RightChildAccessor {
    const ABT& operator()(const ABT& node) const {
        return node.cast<NodeType>()->getRightChild();
    }

    ABT& operator()(ABT& node) const {
        return node.cast<NodeType>()->getRightChild();
    }
};

template <class AboveType,
          class BelowType,
          template <class> class BelowChildAccessor = DefaultChildAccessor>
ReorderDependencies computeDependencies(ABT::reference_type aboveNodeRef,
                                        ABT::reference_type belowNodeRef,
                                        RewriteContext& ctx) {
    // Get variables from above node and check if they are bound at below node, or at below node's
    // child.
    const auto aboveNodeVarNames = collectVariableReferences(aboveNodeRef);

    ABT belowNode = belowNodeRef;
    VariableEnvironment env = VariableEnvironment::build(belowNode, &ctx.getMemo());
    const DefinitionsMap belowNodeDefs = env.hasDefinitions(belowNode.ref())
        ? env.getDefinitions(belowNode.ref())
        : DefinitionsMap{};
    ABT::reference_type belowChild = BelowChildAccessor<BelowType>()(belowNode).ref();
    const DefinitionsMap belowChildNodeDefs =
        env.hasDefinitions(belowChild) ? env.getDefinitions(belowChild) : DefinitionsMap{};

    ReorderDependencies dependencies;
    for (const std::string& varName : aboveNodeVarNames) {
        auto it = belowNodeDefs.find(varName);
        // Variable is exclusively defined in the below node.
        const bool refersToNode = it != belowNodeDefs.cend() && it->second.definedBy == belowNode;
        // Variable is defined in the belowNode's child subtree.
        const bool refersToChild = belowChildNodeDefs.find(varName) != belowChildNodeDefs.cend();

        if (refersToNode) {
            if (refersToChild) {
                dependencies._hasNodeAndChildRef = true;
            } else {
                dependencies._hasNodeRef = true;
            }
        } else if (refersToChild) {
            dependencies._hasChildRef = true;
        } else {
            // Lambda variable. Ignore.
        }
    }

    return dependencies;
}

static ABT createEmptyValueScanNode(const RewriteContext& ctx) {
    using namespace properties;

    const ProjectionNameSet& projNameSet =
        getPropertyConst<ProjectionAvailability>(ctx.getAboveLogicalProps()).getProjections();
    ProjectionNameVector projNameVector;
    projNameVector.insert(projNameVector.begin(), projNameSet.cbegin(), projNameSet.cend());
    return make<ValueScanNode>(std::move(projNameVector), ctx.getAboveLogicalProps());
}

static void addEmptyValueScanNode(RewriteContext& ctx) {
    ABT newNode = createEmptyValueScanNode(ctx);
    ctx.addNode(newNode, true /*substitute*/);
}

static void defaultPropagateEmptyValueScanNode(const ABT& n, RewriteContext& ctx) {
    if (n.cast<ValueScanNode>()->getArraySize() == 0) {
        addEmptyValueScanNode(ctx);
    }
}

template <class AboveType,
          class BelowType,
          template <class> class AboveChildAccessor = DefaultChildAccessor,
          template <class> class BelowChildAccessor = DefaultChildAccessor,
          bool substitute = true>
void defaultReorder(ABT::reference_type aboveNode,
                    ABT::reference_type belowNode,
                    RewriteContext& ctx) {
    ABT newParent = belowNode;
    ABT newChild = aboveNode;

    std::swap(BelowChildAccessor<BelowType>()(newParent),
              AboveChildAccessor<AboveType>()(newChild));
    BelowChildAccessor<BelowType>()(newParent) = std::move(newChild);

    ctx.addNode(newParent, substitute);
}

template <class AboveType, class BelowType>
void defaultReorderWithDependenceCheck(ABT::reference_type aboveNode,
                                       ABT::reference_type belowNode,
                                       RewriteContext& ctx) {
    const ReorderDependencies dependencies =
        computeDependencies<AboveType, BelowType>(aboveNode, belowNode, ctx);
    if (dependencies._hasNodeRef) {
        // Above node refers to a variable bound by below node.
        return;
    }

    defaultReorder<AboveType, BelowType>(aboveNode, belowNode, ctx);
}

template <class AboveType, class BelowType>
struct SubstituteReorder {
    void operator()(ABT::reference_type aboveNode,
                    ABT::reference_type belowNode,
                    RewriteContext& ctx) const {
        defaultReorderWithDependenceCheck<AboveType, BelowType>(aboveNode, belowNode, ctx);
    }
};

template <>
struct SubstituteReorder<FilterNode, FilterNode> {
    void operator()(ABT::reference_type aboveNode,
                    ABT::reference_type belowNode,
                    RewriteContext& ctx) const {
        defaultReorder<FilterNode, FilterNode>(aboveNode, belowNode, ctx);
    }
};

template <>
struct SubstituteReorder<FilterNode, UnionNode> {
    void operator()(ABT::reference_type aboveNode,
                    ABT::reference_type belowNode,
                    RewriteContext& ctx) const {
        ABT newParent = belowNode;

        for (auto& childOfChild : newParent.cast<UnionNode>()->nodes()) {
            ABT aboveCopy = aboveNode;
            std::swap(aboveCopy.cast<FilterNode>()->getChild(), childOfChild);
            std::swap(childOfChild, aboveCopy);
        }

        ctx.addNode(newParent, true /*substitute*/);
    }
};

template <class AboveType>
void unwindBelowReorder(ABT::reference_type aboveNode,
                        ABT::reference_type unwindNode,
                        RewriteContext& ctx) {
    const ReorderDependencies dependencies =
        computeDependencies<AboveType, UnwindNode>(aboveNode, unwindNode, ctx);
    if (dependencies._hasNodeRef || dependencies._hasNodeAndChildRef) {
        // Above node refers to projection being unwound. Reject rewrite.
        return;
    }

    defaultReorder<AboveType, UnwindNode>(aboveNode, unwindNode, ctx);
}

template <>
struct SubstituteReorder<FilterNode, UnwindNode> {
    void operator()(ABT::reference_type aboveNode,
                    ABT::reference_type belowNode,
                    RewriteContext& ctx) const {
        unwindBelowReorder<FilterNode>(aboveNode, belowNode, ctx);
    }
};

template <>
struct SubstituteReorder<EvaluationNode, UnwindNode> {
    void operator()(ABT::reference_type aboveNode,
                    ABT::reference_type belowNode,
                    RewriteContext& ctx) const {
        unwindBelowReorder<EvaluationNode>(aboveNode, belowNode, ctx);
    }
};

template <>
struct SubstituteReorder<UnwindNode, CollationNode> {
    void operator()(ABT::reference_type aboveNode,
                    ABT::reference_type belowNode,
                    RewriteContext& ctx) const {
        const ProjectionNameSet& collationProjections =
            belowNode.cast<CollationNode>()->getProperty().getAffectedProjectionNames();
        if (collationProjections.find(aboveNode.cast<UnwindNode>()->getProjectionName()) !=
            collationProjections.cend()) {
            // A projection being affected by the collation is being unwound. Reject rewrite.
            return;
        }

        defaultReorder<UnwindNode, CollationNode>(aboveNode, belowNode, ctx);
    }
};

template <>
struct SubstituteReorder<FilterNode, ValueScanNode> {
    void operator()(ABT::reference_type aboveNode,
                    ABT::reference_type belowNode,
                    RewriteContext& ctx) const {
        defaultPropagateEmptyValueScanNode(belowNode, ctx);
    }
};

template <>
struct SubstituteReorder<EvaluationNode, ValueScanNode> {
    void operator()(ABT::reference_type aboveNode,
                    ABT::reference_type belowNode,
                    RewriteContext& ctx) const {
        defaultPropagateEmptyValueScanNode(belowNode, ctx);
    }
};

template <>
struct SubstituteReorder<SargableNode, ValueScanNode> {
    void operator()(ABT::reference_type aboveNode,
                    ABT::reference_type belowNode,
                    RewriteContext& ctx) const {
        defaultPropagateEmptyValueScanNode(belowNode, ctx);
    }
};

template <>
struct SubstituteReorder<CollationNode, ValueScanNode> {
    void operator()(ABT::reference_type aboveNode,
                    ABT::reference_type belowNode,
                    RewriteContext& ctx) const {
        defaultPropagateEmptyValueScanNode(belowNode, ctx);
    }
};

template <>
struct SubstituteReorder<LimitSkipNode, ValueScanNode> {
    void operator()(ABT::reference_type aboveNode,
                    ABT::reference_type belowNode,
                    RewriteContext& ctx) const {
        defaultPropagateEmptyValueScanNode(belowNode, ctx);
    }
};

template <>
struct SubstituteReorder<ExchangeNode, ValueScanNode> {
    void operator()(ABT::reference_type aboveNode,
                    ABT::reference_type belowNode,
                    RewriteContext& ctx) const {
        defaultPropagateEmptyValueScanNode(belowNode, ctx);
    }
};

template <class AboveNode, class BelowNode>
struct SubstituteMerge {
    void operator()(ABT::reference_type aboveNode,
                    ABT::reference_type belowNode,
                    RewriteContext& ctx) = delete;
};

template <>
struct SubstituteMerge<CollationNode, CollationNode> {
    void operator()(ABT::reference_type aboveNode,
                    ABT::reference_type belowNode,
                    RewriteContext& ctx) const {
        ABT newRoot = aboveNode;
        // Retain above property.
        newRoot.cast<CollationNode>()->getChild() = belowNode.cast<CollationNode>()->getChild();

        ctx.addNode(newRoot, true /*substitute*/);
    }
};

template <>
struct SubstituteMerge<LimitSkipNode, LimitSkipNode> {
    void operator()(ABT::reference_type aboveNode,
                    ABT::reference_type belowNode,
                    RewriteContext& ctx) const {
        using namespace properties;

        ABT newRoot = aboveNode;
        LimitSkipNode& aboveCollationNode = *newRoot.cast<LimitSkipNode>();
        const LimitSkipNode& belowCollationNode = *belowNode.cast<LimitSkipNode>();

        aboveCollationNode.getChild() = belowCollationNode.getChild();
        combineLimitSkipProperties(aboveCollationNode.getProperty(),
                                   belowCollationNode.getProperty());

        ctx.addNode(newRoot, true /*substitute*/);
    }
};

static boost::optional<ABT> mergeSargableNodes(
    const properties::IndexingAvailability& indexingAvailability,
    const MultikeynessTrie& multikeynessTrie,
    const SargableNode& aboveNode,
    const SargableNode& belowNode,
    RewriteContext& ctx) {
    if (indexingAvailability.getScanGroupId() !=
        belowNode.getChild().cast<MemoLogicalDelegatorNode>()->getGroupId()) {
        // Do not merge if child is not another Sargable node, or the child's child is not a
        // ScanNode.
        return {};
    }

    PartialSchemaRequirements mergedReqs = belowNode.getReqMap();
    ProjectionRenames projectionRenames;
    if (!intersectPartialSchemaReq(mergedReqs, aboveNode.getReqMap(), projectionRenames)) {
        return {};
    }

    const ProjectionName& scanProjName = indexingAvailability.getScanProjection();
    bool hasEmptyInterval = simplifyPartialSchemaReqPaths(
        scanProjName, multikeynessTrie, mergedReqs, ctx.getConstFold());
    if (hasEmptyInterval) {
        return createEmptyValueScanNode(ctx);
    }

    if (mergedReqs.size() > SargableNode::kMaxPartialSchemaReqs) {
        return {};
    }

    const ScanDefinition& scanDef =
        ctx.getMetadata()._scanDefs.at(indexingAvailability.getScanDefName());
    auto candidateIndexes = computeCandidateIndexes(ctx.getPrefixId(),
                                                    scanProjName,
                                                    mergedReqs,
                                                    scanDef,
                                                    ctx.getHints()._fastIndexNullHandling,
                                                    hasEmptyInterval,
                                                    ctx.getConstFold());
    if (hasEmptyInterval) {
        return createEmptyValueScanNode(ctx);
    }

    auto scanParams = computeScanParams(ctx.getPrefixId(), mergedReqs, scanProjName);
    ABT result = make<SargableNode>(std::move(mergedReqs),
                                    std::move(candidateIndexes),
                                    std::move(scanParams),
                                    IndexReqTarget::Complete,
                                    belowNode.getChild());
    applyProjectionRenames(std::move(projectionRenames), result);
    return result;
}

template <>
struct SubstituteMerge<SargableNode, SargableNode> {
    void operator()(ABT::reference_type aboveNode,
                    ABT::reference_type belowNode,
                    RewriteContext& ctx) const {
        using namespace properties;

        const LogicalProps& props = ctx.getAboveLogicalProps();
        tassert(6624170,
                "At this point we should have IndexingAvailability",
                hasProperty<IndexingAvailability>(props));

        const auto& indexingAvailability = getPropertyConst<IndexingAvailability>(props);
        const ScanDefinition& scanDef =
            ctx.getMetadata()._scanDefs.at(indexingAvailability.getScanDefName());
        tassert(6624171, "At this point the collection must exist", scanDef.exists());

        const auto& result = mergeSargableNodes(indexingAvailability,
                                                scanDef.getMultikeynessTrie(),
                                                *aboveNode.cast<SargableNode>(),
                                                *belowNode.cast<SargableNode>(),
                                                ctx);
        if (result) {
            ctx.addNode(*result, true /*substitute*/);
        }
    }
};

template <class Type>
struct SubstituteConvert {
    void operator()(ABT::reference_type nodeRef, RewriteContext& ctx) = delete;
};

template <>
struct SubstituteConvert<LimitSkipNode> {
    void operator()(ABT::reference_type node, RewriteContext& ctx) {
        if (node.cast<LimitSkipNode>()->getProperty().getLimit() == 0) {
            addEmptyValueScanNode(ctx);
        }
    }
};

static void convertFilterToSargableNode(ABT::reference_type node,
                                        const FilterNode& filterNode,
                                        RewriteContext& ctx) {
    using namespace properties;

    const LogicalProps& props = ctx.getAboveLogicalProps();
    if (!hasProperty<IndexingAvailability>(props)) {
        // Can only convert to sargable node if we have indexing availability.
        return;
    }

    const auto& indexingAvailability = getPropertyConst<IndexingAvailability>(props);
    const ScanDefinition& scanDef =
        ctx.getMetadata()._scanDefs.at(indexingAvailability.getScanDefName());
    if (!scanDef.exists()) {
        // Do not attempt to optimize for non-existing collections.
        return;
    }

    auto conversion = convertExprToPartialSchemaReq(
        filterNode.getFilter(), true /*isFilterContext*/, ctx.getPathToInterval());
    if (!conversion) {
        return;
    }

    // Remove any partial schema requirements which do not constrain their input.
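    // A requirement whose interval is fully open (unconstrained on both ends) accepts every
    // input value, so dropping it below does not change the filter's semantics.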
    for (auto it = conversion->_reqMap.cbegin(); it != conversion->_reqMap.cend();) {
        uassert(6624111,
                "Filter partial schema requirement must contain a variable name.",
                !it->first._projectionName.empty());
        uassert(6624112,
                "Filter partial schema requirement cannot bind.",
                !it->second.getBoundProjectionName());
        if (isIntervalReqFullyOpenDNF(it->second.getIntervals())) {
            it = conversion->_reqMap.erase(it);
        } else {
            ++it;
        }
    }

    if (conversion->_reqMap.empty()) {
        // If the filter has no constraints after removing no-ops, then replace with its child. We
        // need to copy the child since we hold it by reference from the memo, and during
        // substitution the current group will be erased.
        ABT newNode = filterNode.getChild();
        ctx.addNode(newNode, true /*substitute*/);
        return;
    }

    const ProjectionName& scanProjName = indexingAvailability.getScanProjection();
    bool hasEmptyInterval = simplifyPartialSchemaReqPaths(
        scanProjName, scanDef.getMultikeynessTrie(), conversion->_reqMap, ctx.getConstFold());
    if (hasEmptyInterval) {
        addEmptyValueScanNode(ctx);
        return;
    }

    if (conversion->_reqMap.size() > SargableNode::kMaxPartialSchemaReqs) {
        // Too many requirements.
        return;
    }

    auto candidateIndexes = computeCandidateIndexes(ctx.getPrefixId(),
                                                    scanProjName,
                                                    conversion->_reqMap,
                                                    scanDef,
                                                    ctx.getHints()._fastIndexNullHandling,
                                                    hasEmptyInterval,
                                                    ctx.getConstFold());
    if (hasEmptyInterval) {
        addEmptyValueScanNode(ctx);
        return;
    }

    auto scanParams = computeScanParams(ctx.getPrefixId(), conversion->_reqMap, scanProjName);
    ABT sargableNode = make<SargableNode>(std::move(conversion->_reqMap),
                                          std::move(candidateIndexes),
                                          std::move(scanParams),
                                          IndexReqTarget::Complete,
                                          filterNode.getChild());

    if (conversion->_retainPredicate) {
        ABT newNode = node;
        newNode.cast<FilterNode>()->getChild() = std::move(sargableNode);
        ctx.addNode(newNode, true /*substitute*/, true /*addExistingNodeWithNewChild*/);
    } else {
        ctx.addNode(sargableNode, true /*substitute*/);
    }
}

static ABT appendFieldPath(const FieldPathType& fieldPath, ABT input) {
    for (size_t index = fieldPath.size(); index-- > 0;) {
        input = make<PathGet>(fieldPath.at(index), std::move(input));
    }
    return input;
}

template <>
struct SubstituteConvert<FilterNode> {
    void operator()(ABT::reference_type node, RewriteContext& ctx) {
        const FilterNode& filterNode = *node.cast<FilterNode>();

        // Sub-rewrite: attempt to de-compose filter. If we have a path with a prefix of PathGet's
        // followed by a PathComposeM, then split into two filter nodes at the composition and
        // retain the prefix for each.
        // TODO: consider using a standalone rewrite.
        if (auto evalFilter = filterNode.getFilter().cast<EvalFilter>(); evalFilter != nullptr) {
            ABT::reference_type pathRef = evalFilter->getPath().ref();
            FieldPathType fieldPath;
            for (;;) {
                if (auto newPath = pathRef.cast<PathGet>(); newPath != nullptr) {
                    fieldPath.push_back(newPath->name());
                    pathRef = newPath->getPath().ref();
                } else {
                    break;
                }
            }

            if (auto composition = pathRef.cast<PathComposeM>(); composition != nullptr) {
                // Remove the path composition and insert two filter nodes.
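                // For example (illustrative): a filter over Get "a" (ComposeM p1 p2) becomes a
                // filter over Get "a" p2 stacked on top of a filter over Get "a" p1, with the
                // shared Get prefix re-applied to each branch via appendFieldPath().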
                ABT filterNode1 = make<FilterNode>(
                    make<EvalFilter>(appendFieldPath(fieldPath, composition->getPath1()),
                                     evalFilter->getInput()),
                    filterNode.getChild());
                ABT filterNode2 = make<FilterNode>(
                    make<EvalFilter>(appendFieldPath(fieldPath, composition->getPath2()),
                                     evalFilter->getInput()),
                    std::move(filterNode1));

                ctx.addNode(filterNode2, true /*substitute*/);
                return;
            }
        }

        convertFilterToSargableNode(node, filterNode, ctx);
    }
};

template <>
struct SubstituteConvert<EvaluationNode> {
    void operator()(ABT::reference_type node, RewriteContext& ctx) {
        using namespace properties;

        const LogicalProps props = ctx.getAboveLogicalProps();
        if (!hasProperty<IndexingAvailability>(props)) {
            // Can only convert to sargable node if we have indexing availability.
            return;
        }

        const auto& indexingAvailability = getPropertyConst<IndexingAvailability>(props);
        const ProjectionName& scanProjName = indexingAvailability.getScanProjection();

        const ScanDefinition& scanDef =
            ctx.getMetadata()._scanDefs.at(indexingAvailability.getScanDefName());
        if (!scanDef.exists()) {
            // Do not attempt to optimize for non-existing collections.
            return;
        }

        const EvaluationNode& evalNode = *node.cast<EvaluationNode>();

        // Sub-rewrite: attempt to convert Keep to a chain of individual evaluations.
        // TODO: consider using a standalone rewrite.
        if (auto evalPathPtr = evalNode.getProjection().cast<EvalPath>(); evalPathPtr != nullptr) {
            if (auto inputPtr = evalPathPtr->getInput().cast<Variable>();
                inputPtr != nullptr && inputPtr->name() == scanProjName) {
                if (auto pathKeepPtr = evalPathPtr->getPath().cast<PathKeep>();
                    pathKeepPtr != nullptr &&
                    pathKeepPtr->getNames().size() < SargableNode::kMaxPartialSchemaReqs) {
                    // Optimization. If we are retaining fields on the root level, generate
                    // EvalNodes with the intention of converting later to a SargableNode after
                    // reordering, in order to be able to cover the fields using a physical scan or
                    // index.
                    ABT result = evalNode.getChild();

                    ABT keepPath = make<PathIdentity>();
                    std::set<std::string> orderedSet;
                    for (const std::string& fieldName : pathKeepPtr->getNames()) {
                        orderedSet.insert(fieldName);
                    }
                    for (const std::string& fieldName : orderedSet) {
                        ProjectionName projName = ctx.getPrefixId().getNextId("fieldProj");
                        result = make<EvaluationNode>(
                            projName,
                            make<EvalPath>(make<PathGet>(fieldName, make<PathIdentity>()),
                                           evalPathPtr->getInput()),
                            std::move(result));

                        maybeComposePath(keepPath,
                                         make<PathField>(fieldName,
                                                         make<PathConstant>(make<Variable>(
                                                             std::move(projName)))));
                    }

                    result = make<EvaluationNode>(
                        evalNode.getProjectionName(),
                        make<EvalPath>(std::move(keepPath), Constant::emptyObject()),
                        std::move(result));

                    ctx.addNode(result, true /*substitute*/);
                    return;
                }
            }
        }

        // We still want to extract sargable nodes from EvalNode to use for PhysicalScans.
        auto conversion = convertExprToPartialSchemaReq(
            evalNode.getProjection(), false /*isFilterContext*/, ctx.getPathToInterval());
        if (!conversion) {
            return;
        }
        uassert(6624165,
                "Should not be getting retainPredicate set for EvalNodes",
                !conversion->_retainPredicate);
        if (conversion->_reqMap.size() != 1) {
            // For evaluation nodes we expect to create a single entry.
            return;
        }

        for (auto& [key, req] : conversion->_reqMap) {
            req = {evalNode.getProjectionName(),
                   std::move(req.getIntervals()),
                   req.getIsPerfOnly()};

            uassert(6624114,
                    "Eval partial schema requirement must contain a variable name.",
                    !key._projectionName.empty());
            uassert(6624115,
                    "Eval partial schema requirement cannot have a range",
                    isIntervalReqFullyOpenDNF(req.getIntervals()));
        }

        bool hasEmptyInterval = false;
        auto candidateIndexes = computeCandidateIndexes(ctx.getPrefixId(),
                                                        scanProjName,
                                                        conversion->_reqMap,
                                                        scanDef,
                                                        ctx.getHints()._fastIndexNullHandling,
                                                        hasEmptyInterval,
                                                        ctx.getConstFold());

        if (hasEmptyInterval) {
            addEmptyValueScanNode(ctx);
            return;
        }

        auto scanParams = computeScanParams(ctx.getPrefixId(), conversion->_reqMap, scanProjName);
        ABT newNode = make<SargableNode>(std::move(conversion->_reqMap),
                                         std::move(candidateIndexes),
                                         std::move(scanParams),
                                         IndexReqTarget::Complete,
                                         evalNode.getChild());
        ctx.addNode(newNode, true /*substitute*/);
    }
};

static void lowerSargableNode(const SargableNode& node, RewriteContext& ctx) {
    ABT n = node.getChild();
    const auto reqMap = node.getReqMap();
    for (const auto& [key, req] : reqMap) {
        lowerPartialSchemaRequirement(key, req, n, ctx.getPathToInterval());
    }
    ctx.addNode(n, true /*clear*/);
}

template <class Type>
struct ExploreConvert {
    void operator()(ABT::reference_type nodeRef, RewriteContext& ctx) = delete;
};

struct SplitRequirementsResult {
    PartialSchemaRequirements _leftReqs;
    PartialSchemaRequirements _rightReqs;

    bool _hasFieldCoverage = true;
};

/**
 * Used to split requirements into left and right side. If "isIndex" is false, this is a separation
 * between "index" and "fetch" predicates, otherwise it is a separation between the two sides of
 * index intersection. The separation handles cases where we may have intervals which include Null
 * and return the value, in which case instead of moving the requirement on the left, we insert a
 * copy on the right side which will fetch the value from the collection. We convert perf-only
 * requirements to non-perf when inserting on the left under "isIndex", otherwise we drop them. The
 * mask parameter represents a bitmask indicating which requirements go on the left (bit is 1) and
 * which go on the right.
 */
static SplitRequirementsResult splitRequirements(
    const size_t mask,
    const bool isIndex,
    const bool fastIndexNullHandling,
    const bool disableYieldingTolerantPlans,
    const std::vector<bool>& isFullyOpen,
    const std::vector<bool>& mayReturnNull,
    const boost::optional<std::set<std::string>>& indexFieldPrefixMapForScanDef,
    const PartialSchemaRequirements& reqMap) {
    SplitRequirementsResult result;
    auto& leftReqs = result._leftReqs;
    auto& rightReqs = result._rightReqs;

    const auto addRequirement = [](PartialSchemaRequirements& reqMap,
                                   PartialSchemaKey key,
                                   boost::optional<ProjectionName> boundProjectionName,
                                   IntervalReqExpr::Node intervals) {
        // We always strip out the perf-only flag.
        reqMap.emplace(key,
                       PartialSchemaRequirement{std::move(boundProjectionName),
                                                std::move(intervals),
                                                false /*isPerfOnly*/});
    };

    size_t index = 0;
    for (const auto& [key, req] : reqMap) {
        if (((1ull << index) & mask) != 0) {
            bool addedToLeft = false;

            if (isIndex || fastIndexNullHandling || !mayReturnNull.at(index)) {
                // We can never return Null values from the requirement.
                if (isIndex || disableYieldingTolerantPlans || req.getIsPerfOnly()) {
                    // Insert into left side unchanged.
                    addRequirement(leftReqs, key, req.getBoundProjectionName(), req.getIntervals());
                } else {
                    // Insert a requirement on the right side too, left side is non-binding.
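                    // Note: keeping a non-binding copy on the left and the binding requirement on
                    // the right means the interval is also re-checked on the fetch (right) side,
                    // which is the behavior toggled off by the "disableYieldingTolerantPlans"
                    // hint above.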
                    addRequirement(
                        leftReqs, key, boost::none /*boundProjectionName*/, req.getIntervals());
                    addRequirement(
                        rightReqs, key, req.getBoundProjectionName(), req.getIntervals());
                }
                addedToLeft = true;
            } else {
                // At this point we should not be seeing perf-only predicates.
                invariant(!req.getIsPerfOnly());

                // We cannot return index values if our interval can possibly contain Null.
                // Instead, we remove the output binding for the left side, and return the value
                // from the right (seek) side.
                if (!isFullyOpen.at(index)) {
                    addRequirement(
                        leftReqs, key, boost::none /*boundProjectionName*/, req.getIntervals());
                    addedToLeft = true;
                }
                addRequirement(rightReqs,
                               key,
                               req.getBoundProjectionName(),
                               disableYieldingTolerantPlans ? IntervalReqExpr::makeSingularDNF()
                                                            : req.getIntervals());
            }

            if (addedToLeft) {
                if (indexFieldPrefixMapForScanDef) {
                    if (auto pathPtr = key._path.cast<PathGet>();
                        pathPtr != nullptr &&
                        indexFieldPrefixMapForScanDef->count(pathPtr->name()) == 0) {
                        // We have found a left requirement which cannot be covered with an
                        // index.
                        result._hasFieldCoverage = false;
                        break;
                    }
                }
            }
        } else if (isIndex || !req.getIsPerfOnly()) {
            addRequirement(rightReqs, key, req.getBoundProjectionName(), req.getIntervals());
        }
        index++;
    }

    return result;
}

template <>
struct ExploreConvert<SargableNode> {
    void operator()(ABT::reference_type node, RewriteContext& ctx) {
        using namespace properties;

        const SargableNode& sargableNode = *node.cast<SargableNode>();
        const IndexReqTarget target = sargableNode.getTarget();
        if (target == IndexReqTarget::Seek) {
            return;
        }

        const LogicalProps& props = ctx.getAboveLogicalProps();
        const auto& indexingAvailability = getPropertyConst<IndexingAvailability>(props);
        const GroupIdType scanGroupId = indexingAvailability.getScanGroupId();
        if (sargableNode.getChild().cast<MemoLogicalDelegatorNode>()->getGroupId() !=
                scanGroupId ||
            !ctx.getMemo().getLogicalNodes(scanGroupId).front().is<ScanNode>()) {
            // We are not sitting above a ScanNode.
            lowerSargableNode(sargableNode, ctx);
            return;
        }

        const std::string& scanDefName = indexingAvailability.getScanDefName();
        const ScanDefinition& scanDef = ctx.getMetadata()._scanDefs.at(scanDefName);
        if (scanDef.getIndexDefs().empty()) {
            // Do not insert RIDIntersect if we do not have indexes available.
            return;
        }

        const auto aboveNodeId = ctx.getAboveNodeId();
        auto& sargableSplitCountMap = ctx.getSargableSplitCountMap();
        const size_t splitCount = sargableSplitCountMap[aboveNodeId];
        if (splitCount > LogicalRewriter::kMaxSargableNodeSplitCount) {
            // We cannot split this node further.
            return;
        }

        const ProjectionName& scanProjectionName = indexingAvailability.getScanProjection();
        if (collectVariableReferences(node) != VariableNameSetType{scanProjectionName}) {
            // Rewrite not applicable if we refer projections other than the scan projection.
            return;
        }

        const bool isIndex = target == IndexReqTarget::Index;

        const auto& indexFieldPrefixMap = ctx.getIndexFieldPrefixMap();
        boost::optional<std::set<std::string>> indexFieldPrefixMapForScanDef;
        if (auto it = indexFieldPrefixMap.find(scanDefName);
            it != indexFieldPrefixMap.cend() && !isIndex) {
            indexFieldPrefixMapForScanDef = it->second;
        }

        const auto& reqMap = sargableNode.getReqMap();

        const bool fastIndexNullHandling = ctx.getHints()._fastIndexNullHandling;
        const bool disableYieldingTolerantPlans = ctx.getHints()._disableYieldingTolerantPlans;

        std::vector<bool> isFullyOpen;
        std::vector<bool> mayReturnNull;
        {
            // Pre-compute if a requirement's interval is fully open.
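            // These per-requirement flags are computed once here because they are reused by
            // splitRequirements() for every candidate mask in the loop below.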
            isFullyOpen.reserve(reqMap.size());
            for (const auto& [key, req] : reqMap) {
                isFullyOpen.push_back(isIntervalReqFullyOpenDNF(req.getIntervals()));
            }

            if (!fastIndexNullHandling && !isIndex) {
                // Pre-compute if needed if a requirement's interval may contain nulls, and also
                // has an output binding.
                mayReturnNull.reserve(reqMap.size());
                for (const auto& [key, req] : reqMap) {
                    mayReturnNull.push_back(req.mayReturnNull(ctx.getConstFold()));
                }
            }
        }

        // We iterate over the possible ways to split N predicates into 2^N subsets, one goes to
        // the left, and the other to the right side. If splitting into Index+Seek (isIndex =
        // false), we try having at least one predicate on the left (mask = 1), and we try all
        // possible subsets. For index intersection however (isIndex = true), we try symmetric
        // partitioning (thus the high bound is 2^(N-1)).
        const size_t reqSize = reqMap.size();
        const size_t highMask = isIndex ? (1ull << (reqSize - 1)) : (1ull << reqSize);
        for (size_t mask = 1; mask < highMask; mask++) {
            SplitRequirementsResult splitResult = splitRequirements(mask,
                                                                    isIndex,
                                                                    fastIndexNullHandling,
                                                                    disableYieldingTolerantPlans,
                                                                    isFullyOpen,
                                                                    mayReturnNull,
                                                                    indexFieldPrefixMapForScanDef,
                                                                    reqMap);

            if (splitResult._leftReqs.empty()) {
                // Can happen if we have intervals containing null.
                invariant(!fastIndexNullHandling && !isIndex);
                continue;
            }

            // Reject. Must have at least one proper interval on either side.
            if (isIndex &&
                (!hasProperIntervals(splitResult._leftReqs) ||
                 !hasProperIntervals(splitResult._rightReqs))) {
                continue;
            }

            if (!splitResult._hasFieldCoverage) {
                // Reject rewrite. No suitable indexes.
                continue;
            }

            bool hasEmptyLeftInterval = false;
            auto leftCandidateIndexes = computeCandidateIndexes(ctx.getPrefixId(),
                                                                scanProjectionName,
                                                                splitResult._leftReqs,
                                                                scanDef,
                                                                fastIndexNullHandling,
                                                                hasEmptyLeftInterval,
                                                                ctx.getConstFold());
            if (isIndex && leftCandidateIndexes.empty()) {
                // Reject rewrite.
                continue;
            }

            bool hasEmptyRightInterval = false;
            auto rightCandidateIndexes = computeCandidateIndexes(ctx.getPrefixId(),
                                                                 scanProjectionName,
                                                                 splitResult._rightReqs,
                                                                 scanDef,
                                                                 fastIndexNullHandling,
                                                                 hasEmptyRightInterval,
                                                                 ctx.getConstFold());
            if (isIndex && rightCandidateIndexes.empty()) {
                // With empty candidate map, reject only if we cannot implement as Seek.
                continue;
            }

            uassert(6624116,
                    "Empty intervals should already be rewritten to empty ValueScan nodes",
                    !hasEmptyLeftInterval && !hasEmptyRightInterval);

            ABT scanDelegator = make<MemoLogicalDelegatorNode>(scanGroupId);
            ABT leftChild = make<SargableNode>(std::move(splitResult._leftReqs),
                                               std::move(leftCandidateIndexes),
                                               boost::none,
                                               IndexReqTarget::Index,
                                               scanDelegator);

            auto rightScanParams =
                computeScanParams(ctx.getPrefixId(), splitResult._rightReqs, scanProjectionName);
            ABT rightChild = splitResult._rightReqs.empty()
                ? scanDelegator
                : make<SargableNode>(std::move(splitResult._rightReqs),
                                     std::move(rightCandidateIndexes),
                                     std::move(rightScanParams),
                                     isIndex ? IndexReqTarget::Index : IndexReqTarget::Seek,
                                     scanDelegator);

            ABT newRoot = make<RIDIntersectNode>(
                scanProjectionName, std::move(leftChild), std::move(rightChild));

            const auto& result = ctx.addNode(newRoot, false /*substitute*/);
            for (const MemoLogicalNodeId nodeId : result.second) {
                if (!(nodeId == aboveNodeId)) {
                    sargableSplitCountMap[nodeId] = splitCount + 1;
                }
            }
        }
    }
};

template <>
struct ExploreConvert<GroupByNode> {
    void operator()(ABT::reference_type node, RewriteContext& ctx) {
        const GroupByNode& groupByNode = *node.cast<GroupByNode>();
        if (groupByNode.getType() != GroupNodeType::Complete) {
            return;
        }

        ProjectionNameVector preaggVariableNames;
        ABTVector preaggExpressions;

        const ABTVector& aggExpressions = groupByNode.getAggregationExpressions();
        for (const ABT& expr : aggExpressions) {
            const FunctionCall* aggPtr = expr.cast<FunctionCall>();
            if (aggPtr == nullptr) {
                return;
            }

            // In order to be able to pre-aggregate for now we expect a simple aggregate like
            // SUM(x).
            const auto& aggFnName = aggPtr->name();
            if (aggFnName != "$sum" && aggFnName != "$min" && aggFnName != "$max") {
                // TODO: allow more functions.
                return;
            }
            uassert(6624117, "Invalid argument count", aggPtr->nodes().size() == 1);

            preaggVariableNames.push_back(ctx.getPrefixId().getNextId("preagg"));
            preaggExpressions.emplace_back(
                make<FunctionCall>(aggFnName, makeSeq(make<Variable>(preaggVariableNames.back()))));
        }

        ABT localGroupBy = make<GroupByNode>(groupByNode.getGroupByProjectionNames(),
                                             std::move(preaggVariableNames),
                                             aggExpressions,
                                             GroupNodeType::Local,
                                             groupByNode.getChild());

        ABT newRoot = make<GroupByNode>(groupByNode.getGroupByProjectionNames(),
                                        groupByNode.getAggregationProjectionNames(),
                                        std::move(preaggExpressions),
                                        GroupNodeType::Global,
                                        std::move(localGroupBy));

        ctx.addNode(newRoot, false /*substitute*/);
    }
};

template <class AboveNode, class BelowNode>
struct ExploreReorder {
    void operator()(ABT::reference_type aboveNode,
                    ABT::reference_type belowNode,
                    RewriteContext& ctx) const = delete;
};

template <class AboveNode>
void reorderAgainstRIDIntersectNode(ABT::reference_type aboveNode,
                                    ABT::reference_type belowNode,
                                    RewriteContext& ctx) {
    const ReorderDependencies leftDeps =
        computeDependencies<AboveNode, RIDIntersectNode, LeftChildAccessor>(
            aboveNode, belowNode, ctx);
    uassert(6624118, "RIDIntersect cannot bind projections", !leftDeps._hasNodeRef);
    const bool hasLeftRef = leftDeps._hasChildRef;

    const ReorderDependencies rightDeps =
        computeDependencies<AboveNode, RIDIntersectNode, RightChildAccessor>(
            aboveNode, belowNode, ctx);
    uassert(6624119, "RIDIntersect cannot bind projections", !rightDeps._hasNodeRef);
    const bool hasRightRef = rightDeps._hasChildRef;

    if (hasLeftRef == hasRightRef) {
        // Both left and right reorderings available means that we refer to both left and right
        // sides.
        return;
    }

    const RIDIntersectNode& node = *belowNode.cast<RIDIntersectNode>();
    const GroupIdType groupIdLeft =
        node.getLeftChild().cast<MemoLogicalDelegatorNode>()->getGroupId();
    const bool hasProperIntervalLeft =
        properties::getPropertyConst<properties::IndexingAvailability>(
            ctx.getMemo().getLogicalProps(groupIdLeft))
            .hasProperInterval();
    if (hasProperIntervalLeft && hasLeftRef) {
        defaultReorder<AboveNode,
                       RIDIntersectNode,
                       DefaultChildAccessor,
                       LeftChildAccessor,
                       false /*substitute*/>(aboveNode, belowNode, ctx);
    }

    const GroupIdType groupIdRight =
        node.getRightChild().cast<MemoLogicalDelegatorNode>()->getGroupId();
    const bool hasProperIntervalRight =
        properties::getPropertyConst<properties::IndexingAvailability>(
            ctx.getMemo().getLogicalProps(groupIdRight))
            .hasProperInterval();
    if (hasProperIntervalRight && hasRightRef) {
        defaultReorder<AboveNode,
                       RIDIntersectNode,
                       DefaultChildAccessor,
                       RightChildAccessor,
                       false /*substitute*/>(aboveNode, belowNode, ctx);
    }
};

template <>
struct ExploreReorder<FilterNode, RIDIntersectNode> {
    void operator()(ABT::reference_type aboveNode,
                    ABT::reference_type belowNode,
                    RewriteContext& ctx) const {
        reorderAgainstRIDIntersectNode<FilterNode>(aboveNode, belowNode, ctx);
    }
};

template <>
struct ExploreReorder<EvaluationNode, RIDIntersectNode> {
    void operator()(ABT::reference_type aboveNode,
                    ABT::reference_type belowNode,
                    RewriteContext& ctx) const {
        reorderAgainstRIDIntersectNode<EvaluationNode>(aboveNode, belowNode, ctx);
    }
};

void LogicalRewriter::registerRewrite(const LogicalRewriteType rewriteType, RewriteFn fn) {
    if (_activeRewriteSet.find(rewriteType) != _activeRewriteSet.cend()) {
        const bool inserted = _rewriteMap.emplace(rewriteType, fn).second;
        invariant(inserted);
    }
}

void LogicalRewriter::initializeRewrites() {
    registerRewrite(
        LogicalRewriteType::FilterEvaluationReorder,
        &LogicalRewriter::bindAboveBelow<FilterNode, EvaluationNode, SubstituteReorder>);
    registerRewrite(LogicalRewriteType::FilterCollationReorder,
                    &LogicalRewriter::bindAboveBelow<FilterNode, CollationNode, SubstituteReorder>);
    registerRewrite(
        LogicalRewriteType::EvaluationCollationReorder,
        &LogicalRewriter::bindAboveBelow<EvaluationNode, CollationNode, SubstituteReorder>);
    registerRewrite(
        LogicalRewriteType::EvaluationLimitSkipReorder,
        &LogicalRewriter::bindAboveBelow<EvaluationNode, LimitSkipNode, SubstituteReorder>);
    registerRewrite(LogicalRewriteType::FilterGroupByReorder,
                    &LogicalRewriter::bindAboveBelow<FilterNode, GroupByNode, SubstituteReorder>);
    registerRewrite(
        LogicalRewriteType::GroupCollationReorder,
        &LogicalRewriter::bindAboveBelow<GroupByNode, CollationNode, SubstituteReorder>);
    registerRewrite(LogicalRewriteType::FilterUnwindReorder,
                    &LogicalRewriter::bindAboveBelow<FilterNode, UnwindNode, SubstituteReorder>);
    registerRewrite(
        LogicalRewriteType::EvaluationUnwindReorder,
        &LogicalRewriter::bindAboveBelow<EvaluationNode, UnwindNode, SubstituteReorder>);
    registerRewrite(LogicalRewriteType::UnwindCollationReorder,
                    &LogicalRewriter::bindAboveBelow<UnwindNode, CollationNode, SubstituteReorder>);
    registerRewrite(LogicalRewriteType::FilterExchangeReorder,
                    &LogicalRewriter::bindAboveBelow<FilterNode, ExchangeNode, SubstituteReorder>);
    registerRewrite(
        LogicalRewriteType::ExchangeEvaluationReorder,
        &LogicalRewriter::bindAboveBelow<ExchangeNode, EvaluationNode, SubstituteReorder>);
    registerRewrite(LogicalRewriteType::FilterUnionReorder,
                    &LogicalRewriter::bindAboveBelow<FilterNode, UnionNode, SubstituteReorder>);
    registerRewrite(
        LogicalRewriteType::CollationMerge,
        &LogicalRewriter::bindAboveBelow<CollationNode, CollationNode, SubstituteMerge>);
    registerRewrite(
        LogicalRewriteType::LimitSkipMerge,
        &LogicalRewriter::bindAboveBelow<LimitSkipNode, LimitSkipNode, SubstituteMerge>);
    registerRewrite(LogicalRewriteType::SargableFilterReorder,
                    &LogicalRewriter::bindAboveBelow<SargableNode, FilterNode, SubstituteReorder>);
    registerRewrite(
        LogicalRewriteType::SargableEvaluationReorder,
        &LogicalRewriter::bindAboveBelow<SargableNode, EvaluationNode, SubstituteReorder>);

    registerRewrite(LogicalRewriteType::LimitSkipSubstitute,
                    &LogicalRewriter::bindSingleNode<LimitSkipNode, SubstituteConvert>);

    registerRewrite(LogicalRewriteType::SargableMerge,
                    &LogicalRewriter::bindAboveBelow<SargableNode, SargableNode, SubstituteMerge>);
    registerRewrite(LogicalRewriteType::FilterSubstitute,
                    &LogicalRewriter::bindSingleNode<FilterNode, SubstituteConvert>);
    registerRewrite(LogicalRewriteType::EvaluationSubstitute,
                    &LogicalRewriter::bindSingleNode<EvaluationNode, SubstituteConvert>);

    registerRewrite(LogicalRewriteType::FilterValueScanPropagate,
                    &LogicalRewriter::bindAboveBelow<FilterNode, ValueScanNode, SubstituteReorder>);
    registerRewrite(
        LogicalRewriteType::EvaluationValueScanPropagate,
        &LogicalRewriter::bindAboveBelow<EvaluationNode, ValueScanNode, SubstituteReorder>);
    registerRewrite(
        LogicalRewriteType::SargableValueScanPropagate,
        &LogicalRewriter::bindAboveBelow<SargableNode, ValueScanNode, SubstituteReorder>);
    registerRewrite(
        LogicalRewriteType::CollationValueScanPropagate,
        &LogicalRewriter::bindAboveBelow<CollationNode, ValueScanNode, SubstituteReorder>);
    registerRewrite(
        LogicalRewriteType::LimitSkipValueScanPropagate,
        &LogicalRewriter::bindAboveBelow<LimitSkipNode, ValueScanNode, SubstituteReorder>);
    registerRewrite(
        LogicalRewriteType::ExchangeValueScanPropagate,
        &LogicalRewriter::bindAboveBelow<ExchangeNode, ValueScanNode, SubstituteReorder>);

    registerRewrite(LogicalRewriteType::GroupByExplore,
                    &LogicalRewriter::bindSingleNode<GroupByNode, ExploreConvert>);
    registerRewrite(LogicalRewriteType::SargableSplit,
                    &LogicalRewriter::bindSingleNode<SargableNode, ExploreConvert>);

    registerRewrite(LogicalRewriteType::FilterRIDIntersectReorder,
                    &LogicalRewriter::bindAboveBelow<FilterNode, RIDIntersectNode, ExploreReorder>);
    registerRewrite(
        LogicalRewriteType::EvaluationRIDIntersectReorder,
        &LogicalRewriter::bindAboveBelow<EvaluationNode, RIDIntersectNode, ExploreReorder>);
}

bool LogicalRewriter::rewriteToFixPoint() {
    int iterationCount = 0;

    while (!_groupsPending.empty()) {
        iterationCount++;
        if (_debugInfo.exceedsIterationLimit(iterationCount)) {
            // Iteration limit exceeded.
            return false;
        }

        const GroupIdType groupId = *_groupsPending.begin();
        rewriteGroup(groupId);
        _groupsPending.erase(groupId);
    }

    return true;
}

void LogicalRewriter::rewriteGroup(const GroupIdType groupId) {
    auto& queue = _memo.getLogicalRewriteQueue(groupId);
    while (!queue.empty()) {
        LogicalRewriteEntry rewriteEntry = std::move(*queue.top());
        // TODO: check if rewriteEntry is different than previous (remove duplicates).
        queue.pop();

        _rewriteMap.at(rewriteEntry._type)(this, rewriteEntry._nodeId, rewriteEntry._type);
    }
}

template <class AboveType, class BelowType, template <class, class> class R>
void LogicalRewriter::bindAboveBelow(const MemoLogicalNodeId nodeMemoId,
                                     const LogicalRewriteType rule) {
    // Get a reference to the node instead of the node itself.
    // Rewrites insert into the memo and can move it.
    ABT::reference_type node = _memo.getNode(nodeMemoId);
    const GroupIdType currentGroupId = nodeMemoId._groupId;

    if (node.is<AboveType>()) {
        // Try to bind as parent.
        const GroupIdType targetGroupId = node.cast<AboveType>()
                                              ->getChild()
                                              .template cast<MemoLogicalDelegatorNode>()
                                              ->getGroupId();

        for (size_t i = 0; i < _memo.getLogicalNodes(targetGroupId).size(); i++) {
            const MemoLogicalNodeId targetNodeId{targetGroupId, i};
            auto targetNode = _memo.getNode(targetNodeId);
            if (targetNode.is<BelowType>()) {
                RewriteContext ctx(*this, rule, nodeMemoId, targetNodeId);
                R<AboveType, BelowType>()(node, targetNode, ctx);
                if (ctx.hasSubstituted()) {
                    return;
                }
            }
        }
    }

    if (node.is<BelowType>()) {
        // Try to bind as child.
        NodeIdSet usageNodeIdSet;
        {
            const auto& inputGroupsToNodeId = _memo.getInputGroupsToNodeIdMap();
            auto it = inputGroupsToNodeId.find({currentGroupId});
            if (it != inputGroupsToNodeId.cend()) {
                usageNodeIdSet = it->second;
            }
        }

        for (const MemoLogicalNodeId& parentNodeId : usageNodeIdSet) {
            auto targetNode = _memo.getNode(parentNodeId);
            if (targetNode.is<AboveType>()) {
                uassert(6624047,
                        "Parent child groupId mismatch (usage map index incorrect?)",
                        targetNode.cast<AboveType>()
                                ->getChild()
                                .template cast<MemoLogicalDelegatorNode>()
                                ->getGroupId() == currentGroupId);

                RewriteContext ctx(*this, rule, parentNodeId, nodeMemoId);
                R<AboveType, BelowType>()(targetNode, node, ctx);
                if (ctx.hasSubstituted()) {
                    return;
                }
            }
        }
    }
}

template <class Type, template <class> class R>
void LogicalRewriter::bindSingleNode(const MemoLogicalNodeId nodeMemoId,
                                     const LogicalRewriteType rule) {
    // Get a reference to the node instead of the node itself.
    // Rewrites insert into the memo and can move it.
    ABT::reference_type node = _memo.getNode(nodeMemoId);
    if (node.is<Type>()) {
        RewriteContext ctx(*this, rule, nodeMemoId);
        R<Type>()(node, ctx);
    }
}

const LogicalRewriter::RewriteSet& LogicalRewriter::getExplorationSet() {
    return _explorationSet;
}

const LogicalRewriter::RewriteSet& LogicalRewriter::getSubstitutionSet() {
    return _substitutionSet;
}

}  // namespace mongo::optimizer::cascades