summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDrew Paroski <drew.paroski@mongodb.com>2022-08-19 16:00:50 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2022-08-30 22:05:08 +0000
commit5ab19c9f44431183462e09c198ef7cb4b4a43bfa (patch)
treeef60d630e3327089edf8dbdaeeb8773e17a0b618
parent3b3266e8202388a76758e951980a55870c70aabf (diff)
downloadmongo-5ab19c9f44431183462e09c198ef7cb4b4a43bfa.tar.gz
SERVER-69020 Update sbe_stage_builder_expression.cpp to use traverseP
-rw-r--r--buildscripts/gdb/mongo_printers.py4
-rw-r--r--src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp8
-rw-r--r--src/mongo/db/exec/sbe/expressions/expression.cpp31
-rw-r--r--src/mongo/db/exec/sbe/expressions/sbe_lambda_test.cpp3
-rw-r--r--src/mongo/db/exec/sbe/sbe_test.cpp15
-rw-r--r--src/mongo/db/exec/sbe/vm/vm.cpp100
-rw-r--r--src/mongo/db/exec/sbe/vm/vm.h15
-rw-r--r--src/mongo/db/query/optimizer/rewrites/path_lower.cpp9
-rw-r--r--src/mongo/db/query/sbe_stage_builder.cpp31
-rw-r--r--src/mongo/db/query/sbe_stage_builder_expression.cpp111
10 files changed, 175 insertions, 152 deletions
diff --git a/buildscripts/gdb/mongo_printers.py b/buildscripts/gdb/mongo_printers.py
index cd6e4ddbb0a..c6230527678 100644
--- a/buildscripts/gdb/mongo_printers.py
+++ b/buildscripts/gdb/mongo_printers.py
@@ -795,7 +795,7 @@ class SbeCodeFragmentPrinter(object):
# Some instructions have extra arguments, embedded into the ops stream.
args = ''
- if op_name in ['pushLocalVal', 'pushMoveLocalVal', 'pushLocalLambda', 'traversePConst']:
+ if op_name in ['pushLocalVal', 'pushMoveLocalVal', 'pushLocalLambda']:
args = 'arg: ' + str(read_as_integer(cur_op, int_size))
cur_op += int_size
elif op_name in ['jmp', 'jmpTrue', 'jmpNothing']:
@@ -829,7 +829,7 @@ class SbeCodeFragmentPrinter(object):
elif op_name in ['fillEmptyConst']:
args = 'Instruction::Constants: ' + str(read_as_integer(cur_op, uint8_size))
cur_op += uint8_size
- elif op_name in ['traverseFConst']:
+ elif op_name in ['traverseFConst', 'traversePConst']:
const_enum = read_as_integer(cur_op, uint8_size)
cur_op += uint8_size
args = \
diff --git a/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp b/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp
index 09c5d4f027b..7f5fa5ee1c4 100644
--- a/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp
+++ b/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp
@@ -104,10 +104,10 @@ TEST_F(ABTSBE, Lower4) {
auto tree = make<FunctionCall>(
"traverseP",
- makeSeq(
- make<Constant>(tagArr, valArr),
- make<LambdaAbstraction>(
- "x", make<BinaryOp>(Operations::Add, make<Variable>("x"), Constant::int64(10)))));
+ makeSeq(make<Constant>(tagArr, valArr),
+ make<LambdaAbstraction>(
+ "x", make<BinaryOp>(Operations::Add, make<Variable>("x"), Constant::int64(10))),
+ Constant::nothing()));
auto env = VariableEnvironment::build(tree);
SlotVarMap map;
diff --git a/src/mongo/db/exec/sbe/expressions/expression.cpp b/src/mongo/db/exec/sbe/expressions/expression.cpp
index 185713839f1..a787dea1284 100644
--- a/src/mongo/db/exec/sbe/expressions/expression.cpp
+++ b/src/mongo/db/exec/sbe/expressions/expression.cpp
@@ -546,7 +546,7 @@ static stdx::unordered_map<std::string, InstrFn> kInstrFunctions = {
{"fillEmpty",
InstrFn{[](size_t n) { return n == 2; }, &vm::CodeFragment::appendFillEmpty, false}},
{"traverseP",
- InstrFn{[](size_t n) { return n == 2; }, &vm::CodeFragment::appendTraverseP, false}},
+ InstrFn{[](size_t n) { return n == 3; }, &vm::CodeFragment::appendTraverseP, false}},
{"traverseF",
InstrFn{[](size_t n) { return n == 3; }, &vm::CodeFragment::appendTraverseF, false}},
{"setField",
@@ -679,21 +679,28 @@ vm::CodeFragment EFunction::compileDirect(CompileCtx& ctx) const {
value::bitcastTo<bool>(val) ? vm::Instruction::True
: vm::Instruction::False);
return code;
- } else if (_name == "traverseP" && _nodes[1]->as<ELocalLambda>()) {
- auto lambda = _nodes[1]->as<ELocalLambda>();
+ } else if (_name == "traverseP" && _nodes[1]->as<ELocalLambda>() &&
+ _nodes[2]->as<EConstant>()) {
+ auto [tag, val] = _nodes[2]->as<EConstant>()->getConstant();
+ if ((tag == value::TypeTags::NumberInt32 && value::bitcastTo<int32_t>(val) == 1) ||
+ tag == value::TypeTags::Nothing) {
+ auto lambda = _nodes[1]->as<ELocalLambda>();
- auto body = lambda->compileBodyDirect(ctx);
- // Jump around the body.
- code.appendJump(body.instrs().size());
+ auto body = lambda->compileBodyDirect(ctx);
+ // Jump around the body.
+ code.appendJump(body.instrs().size());
- // Remember the position and append the body.
- auto bodyPosition = code.instrs().size();
- code.appendNoStack(std::move(body));
+ // Remember the position and append the body.
+ auto bodyPosition = code.instrs().size();
+ code.appendNoStack(std::move(body));
- code.append(_nodes[0]->compileDirect(ctx));
- code.appendTraverseP(bodyPosition);
+ code.append(_nodes[0]->compileDirect(ctx));
- return code;
+ code.appendTraverseP(bodyPosition,
+ tag == value::TypeTags::Nothing ? vm::Instruction::Nothing
+ : vm::Instruction::Int32One);
+ return code;
+ }
} else if (_name == "applyClassicMatcher") {
tassert(6681400,
"First argument to applyClassicMatcher must be constant",
diff --git a/src/mongo/db/exec/sbe/expressions/sbe_lambda_test.cpp b/src/mongo/db/exec/sbe/expressions/sbe_lambda_test.cpp
index 2d3d8447aec..0c9a1fc0c8a 100644
--- a/src/mongo/db/exec/sbe/expressions/sbe_lambda_test.cpp
+++ b/src/mongo/db/exec/sbe/expressions/sbe_lambda_test.cpp
@@ -43,7 +43,8 @@ TEST_F(SBELambdaTest, AddOneToArray) {
makeE<ELocalLambda>(frame,
makeE<EPrimBinary>(EPrimBinary::Op::add,
makeE<EVariable>(frame, 0),
- makeE<EConstant>(constTag, constVal)))));
+ makeE<EConstant>(constTag, constVal))),
+ makeE<EConstant>(value::TypeTags::Nothing, 0)));
auto compiledExpr = compileExpression(*lambdaExpr);
auto bsonArr = BSON_ARRAY(1 << 2 << 3);
diff --git a/src/mongo/db/exec/sbe/sbe_test.cpp b/src/mongo/db/exec/sbe/sbe_test.cpp
index 5a577f02462..c3ffba02394 100644
--- a/src/mongo/db/exec/sbe/sbe_test.cpp
+++ b/src/mongo/db/exec/sbe/sbe_test.cpp
@@ -468,16 +468,21 @@ TEST(SBEVM, CodeFragmentToStringArgs) {
vm::CodeFragment code;
std::string toStringPattern{kAddrPattern};
- code.appendFillEmpty(vm::Instruction::True);
- toStringPattern += instrPattern("fillEmptyConst", "k: True");
code.appendFillEmpty(vm::Instruction::Null);
toStringPattern += instrPattern("fillEmptyConst", "k: Null");
code.appendFillEmpty(vm::Instruction::False);
toStringPattern += instrPattern("fillEmptyConst", "k: False");
+ code.appendFillEmpty(vm::Instruction::True);
+ toStringPattern += instrPattern("fillEmptyConst", "k: True");
- code.appendTraverseP(0xAA);
- auto offsetP = 0xAA - code.instrs().size();
- toStringPattern += instrPattern("traversePConst", "offset: " + std::to_string(offsetP));
+ code.appendTraverseP(0xAA, vm::Instruction::Nothing);
+ auto offsetP1 = 0xAA - code.instrs().size();
+ toStringPattern +=
+ instrPattern("traversePConst", "k: Nothing, offset: " + std::to_string(offsetP1));
+ code.appendTraverseP(0xAA, vm::Instruction::Int32One);
+ auto offsetP2 = 0xAA - code.instrs().size();
+ toStringPattern +=
+ instrPattern("traversePConst", "k: 1, offset: " + std::to_string(offsetP2));
code.appendTraverseF(0xBB, vm::Instruction::True);
auto offsetF = 0xBB - code.instrs().size();
toStringPattern +=
diff --git a/src/mongo/db/exec/sbe/vm/vm.cpp b/src/mongo/db/exec/sbe/vm/vm.cpp
index f530ca38bc7..bb4125471ca 100644
--- a/src/mongo/db/exec/sbe/vm/vm.cpp
+++ b/src/mongo/db/exec/sbe/vm/vm.cpp
@@ -113,7 +113,7 @@ int Instruction::stackOffset[Instruction::Tags::lastInstruction] = {
-1, // getElement
-1, // collComparisonKey
-1, // getFieldOrElement
- -1, // traverseP
+ -2, // traverseP
0, // traversePConst
-2, // traverseF
0, // traverseFConst
@@ -247,8 +247,7 @@ std::string CodeFragment::toString() const {
break;
}
// Instructions with a single integer argument.
- case Instruction::pushLocalLambda:
- case Instruction::traversePConst: {
+ case Instruction::pushLocalLambda: {
auto offset = readFromMemory<int>(pcPointer);
pcPointer += sizeof(offset);
ss << "offset: " << offset;
@@ -270,6 +269,7 @@ std::string CodeFragment::toString() const {
break;
}
// Instructions with other kinds of arguments.
+ case Instruction::traversePConst:
case Instruction::traverseFConst: {
auto k = readFromMemory<Instruction::Constants>(pcPointer);
pcPointer += sizeof(k);
@@ -637,17 +637,18 @@ void CodeFragment::appendIsRecordId() {
appendSimpleInstruction(Instruction::isRecordId);
}
-void CodeFragment::appendTraverseP(int codePosition) {
+void CodeFragment::appendTraverseP(int codePosition, Instruction::Constants k) {
Instruction i;
i.tag = Instruction::traversePConst;
adjustStackSimple(i);
- auto size = sizeof(Instruction) + sizeof(codePosition);
+ auto size = sizeof(Instruction) + sizeof(codePosition) + sizeof(k);
auto offset = allocateSpace(size);
int codeOffset = codePosition - _instrs.size();
offset += writeToMemory(offset, i);
+ offset += writeToMemory(offset, k);
offset += writeToMemory(offset, codeOffset);
}
@@ -919,26 +920,38 @@ std::tuple<bool, value::TypeTags, value::Value> ByteCode::getFieldOrElement(
void ByteCode::traverseP(const CodeFragment* code) {
// Traverse a projection path - evaluate the input lambda on every element of the input array.
// The traversal is recursive; i.e. we visit nested arrays if any.
+ auto [maxDepthOwn, maxDepthTag, maxDepthVal] = getFromStack(0);
+ popAndReleaseStack();
auto [lamOwn, lamTag, lamVal] = getFromStack(0);
popAndReleaseStack();
- if (lamTag != value::TypeTags::LocalLambda) {
+ if ((maxDepthTag != value::TypeTags::Nothing && maxDepthTag != value::TypeTags::NumberInt32) ||
+ lamTag != value::TypeTags::LocalLambda) {
popAndReleaseStack();
pushStack(false, value::TypeTags::Nothing, 0);
+ return;
}
+
int64_t lamPos = value::bitcastTo<int64_t>(lamVal);
+ int64_t maxDepth = maxDepthTag == value::TypeTags::NumberInt32
+ ? value::bitcastTo<int32_t>(maxDepthVal)
+ : std::numeric_limits<int64_t>::max();
- traverseP(code, lamPos);
+ traverseP(code, lamPos, maxDepth);
}
-void ByteCode::traverseP(const CodeFragment* code, int64_t position) {
+void ByteCode::traverseP(const CodeFragment* code, int64_t position, int64_t maxDepth) {
auto [own, tag, val] = getFromStack(0);
- if (value::isArray(tag)) {
+ if (value::isArray(tag) && maxDepth > 0) {
value::ValueGuard input(own ? tag : value::TypeTags::Nothing, own ? val : 0);
popStack();
- traverseP_nested(code, position, tag, val);
+ if (maxDepth != std::numeric_limits<int64_t>::max()) {
+ --maxDepth;
+ }
+
+ traverseP_nested(code, position, tag, val, maxDepth);
} else {
runLambdaInternal(code, position);
}
@@ -947,32 +960,37 @@ void ByteCode::traverseP(const CodeFragment* code, int64_t position) {
void ByteCode::traverseP_nested(const CodeFragment* code,
int64_t position,
value::TypeTags tagInput,
- value::Value valInput) {
- if (value::isArray(tagInput)) {
- auto [tagArrOutput, valArrOutput] = value::makeNewArray();
- auto arrOutput = value::getArrayView(valArrOutput);
- value::ValueGuard guard{tagInput, valArrOutput};
-
- for (value::ArrayEnumerator enumerator(tagInput, valInput); !enumerator.atEnd();
- enumerator.advance()) {
- auto [elemTag, elemVal] = enumerator.getViewOfValue();
- traverseP_nested(code, position, elemTag, elemVal);
- auto [retOwn, retTag, retVal] = getFromStack(0);
- popStack();
- if (!retOwn) {
- auto [copyTag, copyVal] = value::copyValue(retTag, retVal);
- retTag = copyTag;
- retVal = copyVal;
- }
- arrOutput->push_back(retTag, retVal);
+ value::Value valInput,
+ int64_t maxDepth) {
+ auto decrement = [](int64_t d) { return d == std::numeric_limits<int64_t>::max() ? d : d - 1; };
+
+ auto [tagArrOutput, valArrOutput] = value::makeNewArray();
+ auto arrOutput = value::getArrayView(valArrOutput);
+ value::ValueGuard guard{tagInput, valArrOutput};
+
+ for (value::ArrayEnumerator enumerator(tagInput, valInput); !enumerator.atEnd();
+ enumerator.advance()) {
+ auto [elemTag, elemVal] = enumerator.getViewOfValue();
+
+ if (maxDepth > 0 && value::isArray(elemTag)) {
+ traverseP_nested(code, position, elemTag, elemVal, decrement(maxDepth));
+ } else {
+ pushStack(false, elemTag, elemVal);
+ runLambdaInternal(code, position);
}
- guard.reset();
- pushStack(true, tagArrOutput, valArrOutput);
- } else {
- pushStack(false, tagInput, valInput);
- runLambdaInternal(code, position);
+ auto [retOwn, retTag, retVal] = getFromStack(0);
+ popStack();
+ if (!retOwn) {
+ auto [copyTag, copyVal] = value::copyValue(retTag, retVal);
+ retTag = copyTag;
+ retVal = copyVal;
+ }
+ arrOutput->push_back(retTag, retVal);
}
+
+ guard.reset();
+ pushStack(true, tagArrOutput, valArrOutput);
}
void ByteCode::traverseF(const CodeFragment* code) {
@@ -986,6 +1004,7 @@ void ByteCode::traverseF(const CodeFragment* code) {
if (lamTag != value::TypeTags::LocalLambda) {
popAndReleaseStack();
pushStack(false, value::TypeTags::Nothing, 0);
+ return;
}
int64_t lamPos = value::bitcastTo<int64_t>(lamVal);
@@ -5198,6 +5217,8 @@ void ByteCode::runInternal(const CodeFragment* code, int64_t position) {
auto [lhsOwned, lhsTag, lhsVal] = getFromStack(0);
if (lhsTag == value::TypeTags::Nothing) {
switch (k) {
+ case Instruction::Nothing:
+ break;
case Instruction::Null:
topStack(false, value::TypeTags::Null, 0);
break;
@@ -5211,6 +5232,11 @@ void ByteCode::runInternal(const CodeFragment* code, int64_t position) {
value::TypeTags::Boolean,
value::bitcastFrom<bool>(false));
break;
+ case Instruction::Int32One:
+ topStack(false,
+ value::TypeTags::NumberInt32,
+ value::bitcastFrom<int32_t>(1));
+ break;
default:
MONGO_UNREACHABLE;
}
@@ -5331,11 +5357,17 @@ void ByteCode::runInternal(const CodeFragment* code, int64_t position) {
break;
}
case Instruction::traversePConst: {
+ auto k = readFromMemory<Instruction::Constants>(pcPointer);
+ pcPointer += sizeof(k);
+
auto offset = readFromMemory<int>(pcPointer);
pcPointer += sizeof(offset);
auto codePosition = pcPointer - code->instrs().data() + offset;
- traverseP(code, codePosition);
+ traverseP(code,
+ codePosition,
+ k == Instruction::Nothing ? std::numeric_limits<int64_t>::max() : 1);
+
break;
}
case Instruction::traverseF: {
diff --git a/src/mongo/db/exec/sbe/vm/vm.h b/src/mongo/db/exec/sbe/vm/vm.h
index 54a529c0a2e..46793fb4f0d 100644
--- a/src/mongo/db/exec/sbe/vm/vm.h
+++ b/src/mongo/db/exec/sbe/vm/vm.h
@@ -327,19 +327,25 @@ struct Instruction {
};
enum Constants : uint8_t {
+ Nothing,
Null,
- True,
False,
+ True,
+ Int32One,
};
static const char* toStringConstants(Constants k) {
switch (k) {
+ case Nothing:
+ return "Nothing";
case Null:
return "Null";
case True:
return "True";
case False:
return "False";
+ case Int32One:
+ return "1";
default:
return "unknown";
}
@@ -738,7 +744,7 @@ public:
void appendTraverseP() {
appendSimpleInstruction(Instruction::traverseP);
}
- void appendTraverseP(int codePosition);
+ void appendTraverseP(int codePosition, Instruction::Constants k);
void appendTraverseF() {
appendSimpleInstruction(Instruction::traverseF);
}
@@ -995,11 +1001,12 @@ private:
value::Value fieldValue);
void traverseP(const CodeFragment* code);
- void traverseP(const CodeFragment* code, int64_t position);
+ void traverseP(const CodeFragment* code, int64_t position, int64_t maxDepth);
void traverseP_nested(const CodeFragment* code,
int64_t position,
value::TypeTags tag,
- value::Value val);
+ value::Value val,
+ int64_t maxDepth);
void traverseF(const CodeFragment* code);
void traverseF(const CodeFragment* code, int64_t position, bool compareArray);
diff --git a/src/mongo/db/query/optimizer/rewrites/path_lower.cpp b/src/mongo/db/query/optimizer/rewrites/path_lower.cpp
index 31127814729..05a4dfa7d23 100644
--- a/src/mongo/db/query/optimizer/rewrites/path_lower.cpp
+++ b/src/mongo/db/query/optimizer/rewrites/path_lower.cpp
@@ -160,10 +160,11 @@ void EvalPathLowering::transport(ABT& n, const PathTraverse& p, ABT& inner) {
const std::string& name = _prefixId.getNextId("valTraverse");
- n = make<LambdaAbstraction>(
- name,
- make<FunctionCall>("traverseP",
- makeSeq(make<Variable>(name), std::exchange(inner, make<Blackhole>()))));
+ n = make<LambdaAbstraction>(name,
+ make<FunctionCall>("traverseP",
+ makeSeq(make<Variable>(name),
+ std::exchange(inner, make<Blackhole>()),
+ Constant::nothing())));
_changed = true;
}
diff --git a/src/mongo/db/query/sbe_stage_builder.cpp b/src/mongo/db/query/sbe_stage_builder.cpp
index fbcc5dd8e12..26bba21fc4c 100644
--- a/src/mongo/db/query/sbe_stage_builder.cpp
+++ b/src/mongo/db/query/sbe_stage_builder.cpp
@@ -2290,11 +2290,12 @@ EvalStage optimizeFieldPaths(StageBuilderState& state,
if (!state.preGeneratedExprs.contains(fieldPathStr)) {
auto [curEvalExpr, curEvalStage] = generateExpression(
state, fieldExpr, std::move(retEvalStage), optionalRootSlot, nodeId);
- auto optionalFieldPathSlot = curEvalExpr.getSlot();
- tassert(
- 6089300, "Must have a valid slot for "_format(fieldPathStr), optionalFieldPathSlot);
- state.preGeneratedExprs.emplace(fieldPathStr, std::move(curEvalExpr));
- retEvalStage = std::move(curEvalStage);
+
+ auto [slot, stage] = projectEvalExpr(
+ std::move(curEvalExpr), std::move(curEvalStage), nodeId, state.slotIdGenerator);
+
+ state.preGeneratedExprs.emplace(fieldPathStr, slot);
+ retEvalStage = std::move(stage);
}
});
@@ -2344,20 +2345,12 @@ std::tuple<sbe::value::SlotVector, EvalStage, std::unique_ptr<sbe::EExpression>>
nodeId,
slotIdGenerator);
- if (auto optionalSlot = groupByEvalExpr.getSlot(); optionalSlot.has_value()) {
- slots.push_back(*optionalSlot);
- retEvalStage = std::move(groupByEvalStage);
- } else {
- // A projection stage is not generated from 'generateExpression'. So, generates one
- // and binds the slot to the field name.
- auto [slot, tempEvalStage] = projectEvalExpr(groupByEvalExpr.extractExpr(),
- std::move(groupByEvalStage),
- nodeId,
- slotIdGenerator);
- slots.push_back(slot);
- groupByEvalExpr = makeVariable(slot);
- retEvalStage = std::move(tempEvalStage);
- }
+ auto [slot, stage] = projectEvalExpr(
+ std::move(groupByEvalExpr), std::move(groupByEvalStage), nodeId, slotIdGenerator);
+
+ slots.push_back(slot);
+ groupByEvalExpr = slot;
+ retEvalStage = std::move(stage);
exprs.emplace_back(makeConstant(fieldName));
exprs.emplace_back(groupByEvalExpr.extractExpr());
diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp
index 2eca40f5534..f0d78e1c913 100644
--- a/src/mongo/db/query/sbe_stage_builder_expression.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp
@@ -147,90 +147,68 @@ struct ExpressionVisitorContext {
std::stack<int> filterExprChildrenCounter;
};
-std::pair<sbe::value::SlotId, EvalStage> generateTraverseHelper(
- EvalStage inputStage,
- sbe::value::SlotId inputSlot,
+std::unique_ptr<sbe::EExpression> generateTraverseHelper(
+ const sbe::EVariable& inputVar,
const FieldPath& fp,
size_t level,
- PlanNodeId planNodeId,
- sbe::value::SlotIdGenerator* slotIdGenerator) {
+ sbe::value::FrameIdGenerator* frameIdGenerator) {
using namespace std::literals;
invariant(level < fp.getPathLength());
- // The field we will be traversing at the current nested level.
- auto fieldSlot{slotIdGenerator->generate()};
-
- // Generate the projection stage to read a sub-field at the current nested level and bind it
- // to 'fieldSlot'.
- inputStage = makeProject(std::move(inputStage),
- planNodeId,
- fieldSlot,
- makeFunction("getField"_sd,
- sbe::makeE<sbe::EVariable>(inputSlot),
- sbe::makeE<sbe::EConstant>(fp.getFieldName(level))));
+ // Generate an expression to read a sub-field at the current nested level.
+ auto fieldName = sbe::makeE<sbe::EConstant>(fp.getFieldName(level));
+ auto fieldExpr = makeFunction("getField"_sd, inputVar.clone(), std::move(fieldName));
if (level == fp.getPathLength() - 1) {
// For the last level, we can just return the field slot without the need for a
// traverse stage.
- return {fieldSlot, std::move(inputStage)};
+ return fieldExpr;
}
// Generate nested traversal.
- auto [innerResultSlot, innerBranch] = generateTraverseHelper(
- makeLimitCoScanStage(planNodeId), fieldSlot, fp, level + 1, planNodeId, slotIdGenerator);
+ auto lambdaFrameId = frameIdGenerator->generate();
+ auto lambdaParam = sbe::EVariable{lambdaFrameId, 0};
+
+ auto resultExpr = generateTraverseHelper(lambdaParam, fp, level + 1, frameIdGenerator);
+
+ auto lambdaExpr = sbe::makeE<sbe::ELocalLambda>(lambdaFrameId, std::move(resultExpr));
// Generate the traverse stage for the current nested level.
- auto outputSlot{slotIdGenerator->generate()};
- return {outputSlot,
- makeTraverse(std::move(inputStage),
- std::move(innerBranch),
- fieldSlot,
- outputSlot,
- innerResultSlot,
- nullptr,
- nullptr,
- planNodeId,
- 1)};
+ return makeFunction("traverseP",
+ std::move(fieldExpr),
+ std::move(lambdaExpr),
+ makeConstant(sbe::value::TypeTags::NumberInt32, 1));
}
/**
* For the given MatchExpression 'expr', generates a path traversal SBE plan stage sub-tree
* implementing the comparison expression.
*/
-std::pair<sbe::value::SlotId, EvalStage> generateTraverse(
- EvalStage inputStage,
- sbe::value::SlotId inputSlot,
- bool expectsDocumentInputOnly,
- const FieldPath& fp,
- PlanNodeId planNodeId,
- sbe::value::SlotIdGenerator* slotIdGenerator) {
+std::unique_ptr<sbe::EExpression> generateTraverse(const sbe::EVariable& inputVar,
+ bool expectsDocumentInputOnly,
+ const FieldPath& fp,
+ sbe::value::FrameIdGenerator* frameIdGenerator) {
+ size_t level = 0;
+
if (expectsDocumentInputOnly) {
- // When we know for sure that 'inputSlot' will be a document and _not_ an array (such as
+ // When we know for sure that 'inputVar' will be a document and _not_ an array (such as
// when traversing the root document), we can generate a simpler expression.
- return generateTraverseHelper(
- std::move(inputStage), inputSlot, fp, 0, planNodeId, slotIdGenerator);
+ return generateTraverseHelper(inputVar, fp, level, frameIdGenerator);
} else {
- // The general case: the value in the 'inputSlot' may be an array that will require
+ // The general case: the value in the 'inputVar' may be an array that will require
// traversal.
- auto outputSlot{slotIdGenerator->generate()};
- auto [innerBranchOutputSlot, innerBranch] =
- generateTraverseHelper(makeLimitCoScanStage(planNodeId),
- inputSlot,
- fp,
- 0, // level
- planNodeId,
- slotIdGenerator);
- return {outputSlot,
- makeTraverse(std::move(inputStage),
- std::move(innerBranch),
- inputSlot,
- outputSlot,
- innerBranchOutputSlot,
- nullptr,
- nullptr,
- planNodeId,
- 1)};
+ auto lambdaFrameId = frameIdGenerator->generate();
+ auto lambdaParam = sbe::EVariable{lambdaFrameId, 0};
+
+ auto resultExpr = generateTraverseHelper(lambdaParam, fp, level, frameIdGenerator);
+
+ auto lambdaExpr = sbe::makeE<sbe::ELocalLambda>(lambdaFrameId, std::move(resultExpr));
+
+ return makeFunction("traverseP",
+ inputVar.clone(),
+ std::move(lambdaExpr),
+ makeConstant(sbe::value::TypeTags::NumberInt32, 1));
}
}
@@ -1907,14 +1885,13 @@ public:
// Dereference a dotted path, which may contain arrays requiring implicit traversal.
const bool expectsDocumentInputOnly = slotId == *(_context->optionalRootSlot);
- auto [outputSlot, stage] = generateTraverse(_context->extractCurrentEvalStage(),
- slotId,
- expectsDocumentInputOnly,
- expr->getFieldPathWithoutCurrentPrefix(),
- _context->planNodeId,
- _context->state.slotIdGenerator);
-
- _context->pushExpr(outputSlot, std::move(stage));
+
+ auto resultExpr = generateTraverse(sbe::EVariable{slotId},
+ expectsDocumentInputOnly,
+ expr->getFieldPathWithoutCurrentPrefix(),
+ _context->state.frameIdGenerator);
+
+ _context->pushExpr(std::move(resultExpr));
}
void visit(const ExpressionFilter* expr) final {
// Remove index tracking current child of $filter expression, since it is not used anymore.