diff options
author | Drew Paroski <drew.paroski@mongodb.com> | 2022-08-19 16:00:50 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2022-08-30 22:05:08 +0000 |
commit | 5ab19c9f44431183462e09c198ef7cb4b4a43bfa (patch) | |
tree | ef60d630e3327089edf8dbdaeeb8773e17a0b618 | |
parent | 3b3266e8202388a76758e951980a55870c70aabf (diff) | |
download | mongo-5ab19c9f44431183462e09c198ef7cb4b4a43bfa.tar.gz |
SERVER-69020 Update sbe_stage_builder_expression.cpp to use traverseP
-rw-r--r-- | buildscripts/gdb/mongo_printers.py | 4 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp | 8 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/expressions/expression.cpp | 31 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/expressions/sbe_lambda_test.cpp | 3 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/sbe_test.cpp | 15 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/vm/vm.cpp | 100 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/vm/vm.h | 15 | ||||
-rw-r--r-- | src/mongo/db/query/optimizer/rewrites/path_lower.cpp | 9 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder.cpp | 31 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_expression.cpp | 111 |
10 files changed, 175 insertions, 152 deletions
diff --git a/buildscripts/gdb/mongo_printers.py b/buildscripts/gdb/mongo_printers.py index cd6e4ddbb0a..c6230527678 100644 --- a/buildscripts/gdb/mongo_printers.py +++ b/buildscripts/gdb/mongo_printers.py @@ -795,7 +795,7 @@ class SbeCodeFragmentPrinter(object): # Some instructions have extra arguments, embedded into the ops stream. args = '' - if op_name in ['pushLocalVal', 'pushMoveLocalVal', 'pushLocalLambda', 'traversePConst']: + if op_name in ['pushLocalVal', 'pushMoveLocalVal', 'pushLocalLambda']: args = 'arg: ' + str(read_as_integer(cur_op, int_size)) cur_op += int_size elif op_name in ['jmp', 'jmpTrue', 'jmpNothing']: @@ -829,7 +829,7 @@ class SbeCodeFragmentPrinter(object): elif op_name in ['fillEmptyConst']: args = 'Instruction::Constants: ' + str(read_as_integer(cur_op, uint8_size)) cur_op += uint8_size - elif op_name in ['traverseFConst']: + elif op_name in ['traverseFConst', 'traversePConst']: const_enum = read_as_integer(cur_op, uint8_size) cur_op += uint8_size args = \ diff --git a/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp b/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp index 09c5d4f027b..7f5fa5ee1c4 100644 --- a/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp +++ b/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp @@ -104,10 +104,10 @@ TEST_F(ABTSBE, Lower4) { auto tree = make<FunctionCall>( "traverseP", - makeSeq( - make<Constant>(tagArr, valArr), - make<LambdaAbstraction>( - "x", make<BinaryOp>(Operations::Add, make<Variable>("x"), Constant::int64(10))))); + makeSeq(make<Constant>(tagArr, valArr), + make<LambdaAbstraction>( + "x", make<BinaryOp>(Operations::Add, make<Variable>("x"), Constant::int64(10))), + Constant::nothing())); auto env = VariableEnvironment::build(tree); SlotVarMap map; diff --git a/src/mongo/db/exec/sbe/expressions/expression.cpp b/src/mongo/db/exec/sbe/expressions/expression.cpp index 185713839f1..a787dea1284 100644 --- a/src/mongo/db/exec/sbe/expressions/expression.cpp +++ b/src/mongo/db/exec/sbe/expressions/expression.cpp @@ -546,7 +546,7 @@ static stdx::unordered_map<std::string, InstrFn> kInstrFunctions = { {"fillEmpty", InstrFn{[](size_t n) { return n == 2; }, &vm::CodeFragment::appendFillEmpty, false}}, {"traverseP", - InstrFn{[](size_t n) { return n == 2; }, &vm::CodeFragment::appendTraverseP, false}}, + InstrFn{[](size_t n) { return n == 3; }, &vm::CodeFragment::appendTraverseP, false}}, {"traverseF", InstrFn{[](size_t n) { return n == 3; }, &vm::CodeFragment::appendTraverseF, false}}, {"setField", @@ -679,21 +679,28 @@ vm::CodeFragment EFunction::compileDirect(CompileCtx& ctx) const { value::bitcastTo<bool>(val) ? vm::Instruction::True : vm::Instruction::False); return code; - } else if (_name == "traverseP" && _nodes[1]->as<ELocalLambda>()) { - auto lambda = _nodes[1]->as<ELocalLambda>(); + } else if (_name == "traverseP" && _nodes[1]->as<ELocalLambda>() && + _nodes[2]->as<EConstant>()) { + auto [tag, val] = _nodes[2]->as<EConstant>()->getConstant(); + if ((tag == value::TypeTags::NumberInt32 && value::bitcastTo<int32_t>(val) == 1) || + tag == value::TypeTags::Nothing) { + auto lambda = _nodes[1]->as<ELocalLambda>(); - auto body = lambda->compileBodyDirect(ctx); - // Jump around the body. - code.appendJump(body.instrs().size()); + auto body = lambda->compileBodyDirect(ctx); + // Jump around the body. + code.appendJump(body.instrs().size()); - // Remember the position and append the body. - auto bodyPosition = code.instrs().size(); - code.appendNoStack(std::move(body)); + // Remember the position and append the body. + auto bodyPosition = code.instrs().size(); + code.appendNoStack(std::move(body)); - code.append(_nodes[0]->compileDirect(ctx)); - code.appendTraverseP(bodyPosition); + code.append(_nodes[0]->compileDirect(ctx)); - return code; + code.appendTraverseP(bodyPosition, + tag == value::TypeTags::Nothing ? vm::Instruction::Nothing + : vm::Instruction::Int32One); + return code; + } } else if (_name == "applyClassicMatcher") { tassert(6681400, "First argument to applyClassicMatcher must be constant", diff --git a/src/mongo/db/exec/sbe/expressions/sbe_lambda_test.cpp b/src/mongo/db/exec/sbe/expressions/sbe_lambda_test.cpp index 2d3d8447aec..0c9a1fc0c8a 100644 --- a/src/mongo/db/exec/sbe/expressions/sbe_lambda_test.cpp +++ b/src/mongo/db/exec/sbe/expressions/sbe_lambda_test.cpp @@ -43,7 +43,8 @@ TEST_F(SBELambdaTest, AddOneToArray) { makeE<ELocalLambda>(frame, makeE<EPrimBinary>(EPrimBinary::Op::add, makeE<EVariable>(frame, 0), - makeE<EConstant>(constTag, constVal))))); + makeE<EConstant>(constTag, constVal))), + makeE<EConstant>(value::TypeTags::Nothing, 0))); auto compiledExpr = compileExpression(*lambdaExpr); auto bsonArr = BSON_ARRAY(1 << 2 << 3); diff --git a/src/mongo/db/exec/sbe/sbe_test.cpp b/src/mongo/db/exec/sbe/sbe_test.cpp index 5a577f02462..c3ffba02394 100644 --- a/src/mongo/db/exec/sbe/sbe_test.cpp +++ b/src/mongo/db/exec/sbe/sbe_test.cpp @@ -468,16 +468,21 @@ TEST(SBEVM, CodeFragmentToStringArgs) { vm::CodeFragment code; std::string toStringPattern{kAddrPattern}; - code.appendFillEmpty(vm::Instruction::True); - toStringPattern += instrPattern("fillEmptyConst", "k: True"); code.appendFillEmpty(vm::Instruction::Null); toStringPattern += instrPattern("fillEmptyConst", "k: Null"); code.appendFillEmpty(vm::Instruction::False); toStringPattern += instrPattern("fillEmptyConst", "k: False"); + code.appendFillEmpty(vm::Instruction::True); + toStringPattern += instrPattern("fillEmptyConst", "k: True"); - code.appendTraverseP(0xAA); - auto offsetP = 0xAA - code.instrs().size(); - toStringPattern += instrPattern("traversePConst", "offset: " + std::to_string(offsetP)); + code.appendTraverseP(0xAA, vm::Instruction::Nothing); + auto offsetP1 = 0xAA - code.instrs().size(); + toStringPattern += + instrPattern("traversePConst", "k: Nothing, offset: " + std::to_string(offsetP1)); + code.appendTraverseP(0xAA, vm::Instruction::Int32One); + auto offsetP2 = 0xAA - code.instrs().size(); + toStringPattern += + instrPattern("traversePConst", "k: 1, offset: " + std::to_string(offsetP2)); code.appendTraverseF(0xBB, vm::Instruction::True); auto offsetF = 0xBB - code.instrs().size(); toStringPattern += diff --git a/src/mongo/db/exec/sbe/vm/vm.cpp b/src/mongo/db/exec/sbe/vm/vm.cpp index f530ca38bc7..bb4125471ca 100644 --- a/src/mongo/db/exec/sbe/vm/vm.cpp +++ b/src/mongo/db/exec/sbe/vm/vm.cpp @@ -113,7 +113,7 @@ int Instruction::stackOffset[Instruction::Tags::lastInstruction] = { -1, // getElement -1, // collComparisonKey -1, // getFieldOrElement - -1, // traverseP + -2, // traverseP 0, // traversePConst -2, // traverseF 0, // traverseFConst @@ -247,8 +247,7 @@ std::string CodeFragment::toString() const { break; } // Instructions with a single integer argument. - case Instruction::pushLocalLambda: - case Instruction::traversePConst: { + case Instruction::pushLocalLambda: { auto offset = readFromMemory<int>(pcPointer); pcPointer += sizeof(offset); ss << "offset: " << offset; @@ -270,6 +269,7 @@ std::string CodeFragment::toString() const { break; } // Instructions with other kinds of arguments. + case Instruction::traversePConst: case Instruction::traverseFConst: { auto k = readFromMemory<Instruction::Constants>(pcPointer); pcPointer += sizeof(k); @@ -637,17 +637,18 @@ void CodeFragment::appendIsRecordId() { appendSimpleInstruction(Instruction::isRecordId); } -void CodeFragment::appendTraverseP(int codePosition) { +void CodeFragment::appendTraverseP(int codePosition, Instruction::Constants k) { Instruction i; i.tag = Instruction::traversePConst; adjustStackSimple(i); - auto size = sizeof(Instruction) + sizeof(codePosition); + auto size = sizeof(Instruction) + sizeof(codePosition) + sizeof(k); auto offset = allocateSpace(size); int codeOffset = codePosition - _instrs.size(); offset += writeToMemory(offset, i); + offset += writeToMemory(offset, k); offset += writeToMemory(offset, codeOffset); } @@ -919,26 +920,38 @@ std::tuple<bool, value::TypeTags, value::Value> ByteCode::getFieldOrElement( void ByteCode::traverseP(const CodeFragment* code) { // Traverse a projection path - evaluate the input lambda on every element of the input array. // The traversal is recursive; i.e. we visit nested arrays if any. + auto [maxDepthOwn, maxDepthTag, maxDepthVal] = getFromStack(0); + popAndReleaseStack(); auto [lamOwn, lamTag, lamVal] = getFromStack(0); popAndReleaseStack(); - if (lamTag != value::TypeTags::LocalLambda) { + if ((maxDepthTag != value::TypeTags::Nothing && maxDepthTag != value::TypeTags::NumberInt32) || + lamTag != value::TypeTags::LocalLambda) { popAndReleaseStack(); pushStack(false, value::TypeTags::Nothing, 0); + return; } + int64_t lamPos = value::bitcastTo<int64_t>(lamVal); + int64_t maxDepth = maxDepthTag == value::TypeTags::NumberInt32 + ? value::bitcastTo<int32_t>(maxDepthVal) + : std::numeric_limits<int64_t>::max(); - traverseP(code, lamPos); + traverseP(code, lamPos, maxDepth); } -void ByteCode::traverseP(const CodeFragment* code, int64_t position) { +void ByteCode::traverseP(const CodeFragment* code, int64_t position, int64_t maxDepth) { auto [own, tag, val] = getFromStack(0); - if (value::isArray(tag)) { + if (value::isArray(tag) && maxDepth > 0) { value::ValueGuard input(own ? tag : value::TypeTags::Nothing, own ? val : 0); popStack(); - traverseP_nested(code, position, tag, val); + if (maxDepth != std::numeric_limits<int64_t>::max()) { + --maxDepth; + } + + traverseP_nested(code, position, tag, val, maxDepth); } else { runLambdaInternal(code, position); } @@ -947,32 +960,37 @@ void ByteCode::traverseP(const CodeFragment* code, int64_t position) { void ByteCode::traverseP_nested(const CodeFragment* code, int64_t position, value::TypeTags tagInput, - value::Value valInput) { - if (value::isArray(tagInput)) { - auto [tagArrOutput, valArrOutput] = value::makeNewArray(); - auto arrOutput = value::getArrayView(valArrOutput); - value::ValueGuard guard{tagInput, valArrOutput}; - - for (value::ArrayEnumerator enumerator(tagInput, valInput); !enumerator.atEnd(); - enumerator.advance()) { - auto [elemTag, elemVal] = enumerator.getViewOfValue(); - traverseP_nested(code, position, elemTag, elemVal); - auto [retOwn, retTag, retVal] = getFromStack(0); - popStack(); - if (!retOwn) { - auto [copyTag, copyVal] = value::copyValue(retTag, retVal); - retTag = copyTag; - retVal = copyVal; - } - arrOutput->push_back(retTag, retVal); + value::Value valInput, + int64_t maxDepth) { + auto decrement = [](int64_t d) { return d == std::numeric_limits<int64_t>::max() ? d : d - 1; }; + + auto [tagArrOutput, valArrOutput] = value::makeNewArray(); + auto arrOutput = value::getArrayView(valArrOutput); + value::ValueGuard guard{tagInput, valArrOutput}; + + for (value::ArrayEnumerator enumerator(tagInput, valInput); !enumerator.atEnd(); + enumerator.advance()) { + auto [elemTag, elemVal] = enumerator.getViewOfValue(); + + if (maxDepth > 0 && value::isArray(elemTag)) { + traverseP_nested(code, position, elemTag, elemVal, decrement(maxDepth)); + } else { + pushStack(false, elemTag, elemVal); + runLambdaInternal(code, position); } - guard.reset(); - pushStack(true, tagArrOutput, valArrOutput); - } else { - pushStack(false, tagInput, valInput); - runLambdaInternal(code, position); + auto [retOwn, retTag, retVal] = getFromStack(0); + popStack(); + if (!retOwn) { + auto [copyTag, copyVal] = value::copyValue(retTag, retVal); + retTag = copyTag; + retVal = copyVal; + } + arrOutput->push_back(retTag, retVal); } + + guard.reset(); + pushStack(true, tagArrOutput, valArrOutput); } void ByteCode::traverseF(const CodeFragment* code) { @@ -986,6 +1004,7 @@ void ByteCode::traverseF(const CodeFragment* code) { if (lamTag != value::TypeTags::LocalLambda) { popAndReleaseStack(); pushStack(false, value::TypeTags::Nothing, 0); + return; } int64_t lamPos = value::bitcastTo<int64_t>(lamVal); @@ -5198,6 +5217,8 @@ void ByteCode::runInternal(const CodeFragment* code, int64_t position) { auto [lhsOwned, lhsTag, lhsVal] = getFromStack(0); if (lhsTag == value::TypeTags::Nothing) { switch (k) { + case Instruction::Nothing: + break; case Instruction::Null: topStack(false, value::TypeTags::Null, 0); break; @@ -5211,6 +5232,11 @@ void ByteCode::runInternal(const CodeFragment* code, int64_t position) { value::TypeTags::Boolean, value::bitcastFrom<bool>(false)); break; + case Instruction::Int32One: + topStack(false, + value::TypeTags::NumberInt32, + value::bitcastFrom<int32_t>(1)); + break; default: MONGO_UNREACHABLE; } @@ -5331,11 +5357,17 @@ void ByteCode::runInternal(const CodeFragment* code, int64_t position) { break; } case Instruction::traversePConst: { + auto k = readFromMemory<Instruction::Constants>(pcPointer); + pcPointer += sizeof(k); + auto offset = readFromMemory<int>(pcPointer); pcPointer += sizeof(offset); auto codePosition = pcPointer - code->instrs().data() + offset; - traverseP(code, codePosition); + traverseP(code, + codePosition, + k == Instruction::Nothing ? std::numeric_limits<int64_t>::max() : 1); + break; } case Instruction::traverseF: { diff --git a/src/mongo/db/exec/sbe/vm/vm.h b/src/mongo/db/exec/sbe/vm/vm.h index 54a529c0a2e..46793fb4f0d 100644 --- a/src/mongo/db/exec/sbe/vm/vm.h +++ b/src/mongo/db/exec/sbe/vm/vm.h @@ -327,19 +327,25 @@ struct Instruction { }; enum Constants : uint8_t { + Nothing, Null, - True, False, + True, + Int32One, }; static const char* toStringConstants(Constants k) { switch (k) { + case Nothing: + return "Nothing"; case Null: return "Null"; case True: return "True"; case False: return "False"; + case Int32One: + return "1"; default: return "unknown"; } @@ -738,7 +744,7 @@ public: void appendTraverseP() { appendSimpleInstruction(Instruction::traverseP); } - void appendTraverseP(int codePosition); + void appendTraverseP(int codePosition, Instruction::Constants k); void appendTraverseF() { appendSimpleInstruction(Instruction::traverseF); } @@ -995,11 +1001,12 @@ private: value::Value fieldValue); void traverseP(const CodeFragment* code); - void traverseP(const CodeFragment* code, int64_t position); + void traverseP(const CodeFragment* code, int64_t position, int64_t maxDepth); void traverseP_nested(const CodeFragment* code, int64_t position, value::TypeTags tag, - value::Value val); + value::Value val, + int64_t maxDepth); void traverseF(const CodeFragment* code); void traverseF(const CodeFragment* code, int64_t position, bool compareArray); diff --git a/src/mongo/db/query/optimizer/rewrites/path_lower.cpp b/src/mongo/db/query/optimizer/rewrites/path_lower.cpp index 31127814729..05a4dfa7d23 100644 --- a/src/mongo/db/query/optimizer/rewrites/path_lower.cpp +++ b/src/mongo/db/query/optimizer/rewrites/path_lower.cpp @@ -160,10 +160,11 @@ void EvalPathLowering::transport(ABT& n, const PathTraverse& p, ABT& inner) { const std::string& name = _prefixId.getNextId("valTraverse"); - n = make<LambdaAbstraction>( - name, - make<FunctionCall>("traverseP", - makeSeq(make<Variable>(name), std::exchange(inner, make<Blackhole>())))); + n = make<LambdaAbstraction>(name, + make<FunctionCall>("traverseP", + makeSeq(make<Variable>(name), + std::exchange(inner, make<Blackhole>()), + Constant::nothing()))); _changed = true; } diff --git a/src/mongo/db/query/sbe_stage_builder.cpp b/src/mongo/db/query/sbe_stage_builder.cpp index fbcc5dd8e12..26bba21fc4c 100644 --- a/src/mongo/db/query/sbe_stage_builder.cpp +++ b/src/mongo/db/query/sbe_stage_builder.cpp @@ -2290,11 +2290,12 @@ EvalStage optimizeFieldPaths(StageBuilderState& state, if (!state.preGeneratedExprs.contains(fieldPathStr)) { auto [curEvalExpr, curEvalStage] = generateExpression( state, fieldExpr, std::move(retEvalStage), optionalRootSlot, nodeId); - auto optionalFieldPathSlot = curEvalExpr.getSlot(); - tassert( - 6089300, "Must have a valid slot for "_format(fieldPathStr), optionalFieldPathSlot); - state.preGeneratedExprs.emplace(fieldPathStr, std::move(curEvalExpr)); - retEvalStage = std::move(curEvalStage); + + auto [slot, stage] = projectEvalExpr( + std::move(curEvalExpr), std::move(curEvalStage), nodeId, state.slotIdGenerator); + + state.preGeneratedExprs.emplace(fieldPathStr, slot); + retEvalStage = std::move(stage); } }); @@ -2344,20 +2345,12 @@ std::tuple<sbe::value::SlotVector, EvalStage, std::unique_ptr<sbe::EExpression>> nodeId, slotIdGenerator); - if (auto optionalSlot = groupByEvalExpr.getSlot(); optionalSlot.has_value()) { - slots.push_back(*optionalSlot); - retEvalStage = std::move(groupByEvalStage); - } else { - // A projection stage is not generated from 'generateExpression'. So, generates one - // and binds the slot to the field name. - auto [slot, tempEvalStage] = projectEvalExpr(groupByEvalExpr.extractExpr(), - std::move(groupByEvalStage), - nodeId, - slotIdGenerator); - slots.push_back(slot); - groupByEvalExpr = makeVariable(slot); - retEvalStage = std::move(tempEvalStage); - } + auto [slot, stage] = projectEvalExpr( + std::move(groupByEvalExpr), std::move(groupByEvalStage), nodeId, slotIdGenerator); + + slots.push_back(slot); + groupByEvalExpr = slot; + retEvalStage = std::move(stage); exprs.emplace_back(makeConstant(fieldName)); exprs.emplace_back(groupByEvalExpr.extractExpr()); diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp index 2eca40f5534..f0d78e1c913 100644 --- a/src/mongo/db/query/sbe_stage_builder_expression.cpp +++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp @@ -147,90 +147,68 @@ struct ExpressionVisitorContext { std::stack<int> filterExprChildrenCounter; }; -std::pair<sbe::value::SlotId, EvalStage> generateTraverseHelper( - EvalStage inputStage, - sbe::value::SlotId inputSlot, +std::unique_ptr<sbe::EExpression> generateTraverseHelper( + const sbe::EVariable& inputVar, const FieldPath& fp, size_t level, - PlanNodeId planNodeId, - sbe::value::SlotIdGenerator* slotIdGenerator) { + sbe::value::FrameIdGenerator* frameIdGenerator) { using namespace std::literals; invariant(level < fp.getPathLength()); - // The field we will be traversing at the current nested level. - auto fieldSlot{slotIdGenerator->generate()}; - - // Generate the projection stage to read a sub-field at the current nested level and bind it - // to 'fieldSlot'. - inputStage = makeProject(std::move(inputStage), - planNodeId, - fieldSlot, - makeFunction("getField"_sd, - sbe::makeE<sbe::EVariable>(inputSlot), - sbe::makeE<sbe::EConstant>(fp.getFieldName(level)))); + // Generate an expression to read a sub-field at the current nested level. + auto fieldName = sbe::makeE<sbe::EConstant>(fp.getFieldName(level)); + auto fieldExpr = makeFunction("getField"_sd, inputVar.clone(), std::move(fieldName)); if (level == fp.getPathLength() - 1) { // For the last level, we can just return the field slot without the need for a // traverse stage. - return {fieldSlot, std::move(inputStage)}; + return fieldExpr; } // Generate nested traversal. - auto [innerResultSlot, innerBranch] = generateTraverseHelper( - makeLimitCoScanStage(planNodeId), fieldSlot, fp, level + 1, planNodeId, slotIdGenerator); + auto lambdaFrameId = frameIdGenerator->generate(); + auto lambdaParam = sbe::EVariable{lambdaFrameId, 0}; + + auto resultExpr = generateTraverseHelper(lambdaParam, fp, level + 1, frameIdGenerator); + + auto lambdaExpr = sbe::makeE<sbe::ELocalLambda>(lambdaFrameId, std::move(resultExpr)); // Generate the traverse stage for the current nested level. - auto outputSlot{slotIdGenerator->generate()}; - return {outputSlot, - makeTraverse(std::move(inputStage), - std::move(innerBranch), - fieldSlot, - outputSlot, - innerResultSlot, - nullptr, - nullptr, - planNodeId, - 1)}; + return makeFunction("traverseP", + std::move(fieldExpr), + std::move(lambdaExpr), + makeConstant(sbe::value::TypeTags::NumberInt32, 1)); } /** * For the given MatchExpression 'expr', generates a path traversal SBE plan stage sub-tree * implementing the comparison expression. */ -std::pair<sbe::value::SlotId, EvalStage> generateTraverse( - EvalStage inputStage, - sbe::value::SlotId inputSlot, - bool expectsDocumentInputOnly, - const FieldPath& fp, - PlanNodeId planNodeId, - sbe::value::SlotIdGenerator* slotIdGenerator) { +std::unique_ptr<sbe::EExpression> generateTraverse(const sbe::EVariable& inputVar, + bool expectsDocumentInputOnly, + const FieldPath& fp, + sbe::value::FrameIdGenerator* frameIdGenerator) { + size_t level = 0; + if (expectsDocumentInputOnly) { - // When we know for sure that 'inputSlot' will be a document and _not_ an array (such as + // When we know for sure that 'inputVar' will be a document and _not_ an array (such as // when traversing the root document), we can generate a simpler expression. - return generateTraverseHelper( - std::move(inputStage), inputSlot, fp, 0, planNodeId, slotIdGenerator); + return generateTraverseHelper(inputVar, fp, level, frameIdGenerator); } else { - // The general case: the value in the 'inputSlot' may be an array that will require + // The general case: the value in the 'inputVar' may be an array that will require // traversal. - auto outputSlot{slotIdGenerator->generate()}; - auto [innerBranchOutputSlot, innerBranch] = - generateTraverseHelper(makeLimitCoScanStage(planNodeId), - inputSlot, - fp, - 0, // level - planNodeId, - slotIdGenerator); - return {outputSlot, - makeTraverse(std::move(inputStage), - std::move(innerBranch), - inputSlot, - outputSlot, - innerBranchOutputSlot, - nullptr, - nullptr, - planNodeId, - 1)}; + auto lambdaFrameId = frameIdGenerator->generate(); + auto lambdaParam = sbe::EVariable{lambdaFrameId, 0}; + + auto resultExpr = generateTraverseHelper(lambdaParam, fp, level, frameIdGenerator); + + auto lambdaExpr = sbe::makeE<sbe::ELocalLambda>(lambdaFrameId, std::move(resultExpr)); + + return makeFunction("traverseP", + inputVar.clone(), + std::move(lambdaExpr), + makeConstant(sbe::value::TypeTags::NumberInt32, 1)); } } @@ -1907,14 +1885,13 @@ public: // Dereference a dotted path, which may contain arrays requiring implicit traversal. const bool expectsDocumentInputOnly = slotId == *(_context->optionalRootSlot); - auto [outputSlot, stage] = generateTraverse(_context->extractCurrentEvalStage(), - slotId, - expectsDocumentInputOnly, - expr->getFieldPathWithoutCurrentPrefix(), - _context->planNodeId, - _context->state.slotIdGenerator); - - _context->pushExpr(outputSlot, std::move(stage)); + + auto resultExpr = generateTraverse(sbe::EVariable{slotId}, + expectsDocumentInputOnly, + expr->getFieldPathWithoutCurrentPrefix(), + _context->state.frameIdGenerator); + + _context->pushExpr(std::move(resultExpr)); } void visit(const ExpressionFilter* expr) final { // Remove index tracking current child of $filter expression, since it is not used anymore. |