diff options
author | Rui Liu <lriuui0x0@gmail.com> | 2023-01-11 13:49:13 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2023-01-11 14:24:48 +0000 |
commit | 96baa4302afec09851b3ef78bd7c783a32365ee6 (patch) | |
tree | 73979741d677f4029c42177b621464caada0cc08 /src/mongo/db | |
parent | 37bbca32d5adc3654dff73b94816a7bd56cd9d42 (diff) | |
download | mongo-96baa4302afec09851b3ef78bd7c783a32365ee6.tar.gz |
SERVER-71577 Implement math expressions in ABT
Diffstat (limited to 'src/mongo/db')
4 files changed, 589 insertions, 31 deletions
diff --git a/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp b/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp index a028174ae65..8967f05841c 100644 --- a/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp +++ b/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp @@ -349,14 +349,14 @@ TEST_F(SbeStageBuilderGroupTest, TestIdNumericExprOnNonNumericData) { << "2"))}; runGroupAggregationToFail( - R"({_id: {"$add": ["$a", "$b"]}})", docs, static_cast<ErrorCodes::Error>(4974201)); + R"({_id: {"$add": ["$a", "$b"]}})", docs, static_cast<ErrorCodes::Error>(7157723)); runGroupAggregationToFail( - R"({_id: {"$multiply": ["$b", 1000]}})", docs, static_cast<ErrorCodes::Error>(5073102)); + R"({_id: {"$multiply": ["$b", 1000]}})", docs, static_cast<ErrorCodes::Error>(7157721)); runGroupAggregationToFail(R"({_id: {"$divide": [{"$multiply": ["$a", 1000]}, "$b"]}})", docs, - static_cast<ErrorCodes::Error>(5073101)); + static_cast<ErrorCodes::Error>(7157719)); } TEST_F(SbeStageBuilderGroupTest, TestIdObjectExpression) { diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp index b5f7b27111d..80e81047af2 100644 --- a/src/mongo/db/query/sbe_stage_builder_expression.cpp +++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp @@ -787,6 +787,10 @@ public: } void visit(const ExpressionAbs* expr) final { + if (_context->hasAllAbtEligibleEntries(1)) { + return visitABT(expr); + } + auto frameId = _context->state.frameId(); auto binds = sbe::makeEs(_context->popExpr()); sbe::EVariable inputRef(frameId, 0); @@ -806,9 +810,31 @@ public: sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(absExpr))); } + void visitABT(const ExpressionAbs* expr) { + auto inputName = makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + + auto absExpr = buildABTMultiBranchConditional( + ABTCaseValuePair{generateABTNullOrMissing(inputName), optimizer::Constant::null()}, + ABTCaseValuePair{ + generateABTNonNumericCheck(inputName), + makeABTFail(ErrorCodes::Error{7157700}, "$abs only supports numeric types")}, + ABTCaseValuePair{ + generateABTLongLongMinCheck(inputName), + makeABTFail(ErrorCodes::Error{7157701}, "can't take $abs of long long min")}, + makeABTFunction("abs", optimizer::make<optimizer::Variable>(inputName))); + + _context->pushExpr(optimizer::make<optimizer::Let>( + std::move(inputName), _context->popABTExpr(), std::move(absExpr))); + } + void visit(const ExpressionAdd* expr) final { size_t arity = expr->getChildren().size(); _context->ensureArity(arity); + + if (_context->hasAllAbtEligibleEntries(arity)) { + return visitABT(expr); + } + auto frameId = _context->state.frameId(); auto generateNotNumberOrDate = [frameId](const sbe::value::SlotId slotId) { @@ -886,6 +912,83 @@ public: } } + void visitABT(const ExpressionAdd* expr) { + size_t arity = expr->getChildren().size(); + + if (arity == 0) { + // Return a zero constant if the expression has no operand children. + _context->pushExpr(optimizer::Constant::int32(0)); + } else { + optimizer::ABTVector binds; + optimizer::ProjectionNameVector names; + optimizer::ABTVector checkExprsNull; + optimizer::ABTVector checkExprsNotNumberOrDate; + binds.reserve(arity); + names.reserve(arity); + checkExprsNull.reserve(arity); + checkExprsNotNumberOrDate.reserve(arity); + for (size_t idx = 0; idx < arity; ++idx) { + binds.push_back(_context->popABTExpr()); + auto currentName = + makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + names.push_back(currentName); + + checkExprsNull.push_back(generateABTNullOrMissing(currentName)); + checkExprsNotNumberOrDate.push_back(optimizer::make<optimizer::BinaryOp>( + optimizer::Operations::And, + makeNot(makeABTFunction("isNumber", + optimizer::make<optimizer::Variable>(currentName))), + makeNot(makeABTFunction("isDate", + optimizer::make<optimizer::Variable>(currentName))))); + } + + // At this point 'binds' vector contains arguments of $add expression in the reversed + // order. We need to reverse it back to perform summation in the right order below. + // Summation in different order can lead to different result because of accumulated + // precision errors from floating point types. + std::reverse(std::begin(binds), std::end(binds)); + + auto checkNullAllArguments = + makeBalancedBooleanOpTree(optimizer::Operations::Or, std::move(checkExprsNull)); + auto checkNotNumberOrDateAllArguments = makeBalancedBooleanOpTree( + optimizer::Operations::Or, std::move(checkExprsNotNumberOrDate)); + auto addOp = optimizer::make<optimizer::Variable>(names[0]); + for (size_t idx = 1; idx < arity; ++idx) { + auto accName = + makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + addOp = optimizer::make<optimizer::Let>( + accName, + std::move(addOp), + optimizer::make<optimizer::If>( + optimizer::make<optimizer::BinaryOp>( + optimizer::Operations::And, + makeABTFunction("isDate", + optimizer::make<optimizer::Variable>(accName)), + makeABTFunction("isDate", + optimizer::make<optimizer::Variable>(names[idx]))), + makeABTFail(ErrorCodes::Error(7157722), + "only one date allowed in an $add expression"), + optimizer::make<optimizer::BinaryOp>( + optimizer::Operations::Add, + optimizer::make<optimizer::Variable>(accName), + optimizer::make<optimizer::Variable>(names[idx])))); + } + auto addExpr = buildABTMultiBranchConditional( + ABTCaseValuePair{std::move(checkNullAllArguments), optimizer::Constant::null()}, + ABTCaseValuePair{ + std::move(checkNotNumberOrDateAllArguments), + makeABTFail(ErrorCodes::Error{7157723}, + "only numbers and dates are allowed in an $add expression")}, + std::move(addOp)); + + for (size_t i = 0; i < arity; ++i) { + addExpr = optimizer::make<optimizer::Let>( + std::move(names[i]), std::move(binds[i]), std::move(addExpr)); + } + _context->pushExpr(std::move(addExpr)); + } + } + void visit(const ExpressionAllElementsTrue* expr) final { unsupportedExpression(expr->getOpName()); } @@ -1016,6 +1119,10 @@ public: } void visit(const ExpressionCeil* expr) final { + if (_context->hasAllAbtEligibleEntries(1)) { + return visitABT(expr); + } + auto frameId = _context->state.frameId(); auto binds = sbe::makeEs(_context->popExpr()); sbe::EVariable inputRef(frameId, 0); @@ -1031,6 +1138,19 @@ public: _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(ceilExpr))); } + void visitABT(const ExpressionCeil* expr) { + auto inputName = makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + + auto ceilExpr = buildABTMultiBranchConditional( + ABTCaseValuePair{generateABTNullOrMissing(inputName), optimizer::Constant::null()}, + ABTCaseValuePair{ + generateABTNonNumericCheck(inputName), + makeABTFail(ErrorCodes::Error{7157702}, "$ceil only supports numeric types")}, + makeABTFunction("ceil", optimizer::make<optimizer::Variable>(inputName))); + + _context->pushExpr(optimizer::make<optimizer::Let>( + std::move(inputName), _context->popABTExpr(), std::move(ceilExpr))); + } void visit(const ExpressionCoerceToBool* expr) final { // Since $coerceToBool is internal-only and there are not yet any input expressions that // generate an ExpressionCoerceToBool expression, we will leave it as unreachable for now. @@ -1329,9 +1449,7 @@ public: std::move(resultName), optimizer::make<optimizer::FunctionCall>("concatArrays", std::move(argVars)), optimizer::make<optimizer::If>( - optimizer::make<optimizer::FunctionCall>( - "exists", - optimizer::ABTVector{optimizer::make<optimizer::Variable>(resultName)}), + makeABTFunction("exists", optimizer::make<optimizer::Variable>(resultName)), optimizer::make<optimizer::Variable>(resultName), std::move(nullOrFailExpr))); @@ -2078,6 +2196,10 @@ public: void visit(const ExpressionDivide* expr) final { _context->ensureArity(2); + if (_context->hasAllAbtEligibleEntries(2)) { + return visitABT(expr); + } + auto rhs = _context->popExpr(); auto lhs = _context->popExpr(); @@ -2105,7 +2227,43 @@ public: _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(divideExpr))); } + void visitABT(const ExpressionDivide* expr) { + auto rhs = _context->popABTExpr(); + auto lhs = _context->popABTExpr(); + + auto lhsName = makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + auto rhsName = makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + + auto checkIsNumber = optimizer::make<optimizer::BinaryOp>( + optimizer::Operations::And, + makeABTFunction("isNumber", optimizer::make<optimizer::Variable>(lhsName)), + makeABTFunction("isNumber", optimizer::make<optimizer::Variable>(rhsName))); + + auto checkIsNullOrMissing = + optimizer::make<optimizer::BinaryOp>(optimizer::Operations::Or, + generateABTNullOrMissing(lhsName), + generateABTNullOrMissing(rhsName)); + + auto divideExpr = buildABTMultiBranchConditional( + ABTCaseValuePair{std::move(checkIsNullOrMissing), optimizer::Constant::null()}, + ABTCaseValuePair{std::move(checkIsNumber), + optimizer::make<optimizer::BinaryOp>( + optimizer::Operations::Div, + optimizer::make<optimizer::Variable>(lhsName), + optimizer::make<optimizer::Variable>(rhsName))}, + makeABTFail(ErrorCodes::Error{7157719}, "$divide only supports numeric types")); + + _context->pushExpr(optimizer::make<optimizer::Let>( + std::move(lhsName), + std::move(lhs), + optimizer::make<optimizer::Let>( + std::move(rhsName), std::move(rhs), std::move(divideExpr)))); + } void visit(const ExpressionExp* expr) final { + if (_context->hasAllAbtEligibleEntries(1)) { + return visitABT(expr); + } + auto frameId = _context->state.frameId(); auto binds = sbe::makeEs(_context->popExpr()); sbe::EVariable inputRef(frameId, 0); @@ -2121,6 +2279,19 @@ public: _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(expExpr))); } + void visitABT(const ExpressionExp* expr) { + auto inputName = makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + + auto expExpr = buildABTMultiBranchConditional( + ABTCaseValuePair{generateABTNullOrMissing(inputName), optimizer::Constant::null()}, + ABTCaseValuePair{ + generateABTNonNumericCheck(inputName), + makeABTFail(ErrorCodes::Error{7157704}, "$exp only supports numeric types")}, + makeABTFunction("exp", optimizer::make<optimizer::Variable>(inputName))); + + _context->pushExpr(optimizer::make<optimizer::Let>( + std::move(inputName), _context->popABTExpr(), std::move(expExpr))); + } void visit(const ExpressionFieldPath* expr) final { EvalExpr inputExpr; boost::optional<sbe::value::SlotId> topLevelFieldSlot; @@ -2322,6 +2493,10 @@ public: } void visit(const ExpressionFloor* expr) final { + if (_context->hasAllAbtEligibleEntries(1)) { + return visitABT(expr); + } + auto frameId = _context->state.frameId(); auto binds = sbe::makeEs(_context->popExpr()); sbe::EVariable inputRef(frameId, 0); @@ -2337,6 +2512,19 @@ public: _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(floorExpr))); } + void visitABT(const ExpressionFloor* expr) { + auto inputName = makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + + auto floorExpr = buildABTMultiBranchConditional( + ABTCaseValuePair{generateABTNullOrMissing(inputName), optimizer::Constant::null()}, + ABTCaseValuePair{ + generateABTNonNumericCheck(inputName), + makeABTFail(ErrorCodes::Error{7157703}, "$floor only supports numeric types")}, + makeABTFunction("floor", optimizer::make<optimizer::Variable>(inputName))); + + _context->pushExpr(optimizer::make<optimizer::Let>( + std::move(inputName), _context->popABTExpr(), std::move(floorExpr))); + } void visit(const ExpressionIfNull* expr) final { auto numChildren = expr->getChildren().size(); invariant(numChildren >= 2); @@ -2407,6 +2595,10 @@ public: _context->varsFrameStack.pop(); } void visit(const ExpressionLn* expr) final { + if (_context->hasAllAbtEligibleEntries(1)) { + return visitABT(expr); + } + auto frameId = _context->state.frameId(); auto binds = sbe::makeEs(_context->popExpr()); sbe::EVariable inputRef(frameId, 0); @@ -2430,10 +2622,37 @@ public: _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(lnExpr))); } + void visitABT(const ExpressionLn* expr) { + auto inputName = makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + + auto lnExpr = buildABTMultiBranchConditional( + ABTCaseValuePair{generateABTNullOrMissing(inputName), optimizer::Constant::null()}, + ABTCaseValuePair{ + generateABTNonNumericCheck(inputName), + makeABTFail(ErrorCodes::Error{7157705}, "$ln only supports numeric types")}, + // Note: In MQL, $ln on a NumberDecimal NaN historically evaluates to a NumberDouble + // NaN. + ABTCaseValuePair{generateABTNaNCheck(inputName), + makeABTFunction("convert", + optimizer::make<optimizer::Variable>(inputName), + optimizer::Constant::int32(static_cast<int32_t>( + sbe::value::TypeTags::NumberDouble)))}, + ABTCaseValuePair{generateABTNonPositiveCheck(inputName), + makeABTFail(ErrorCodes::Error{7157706}, + "$ln's argument must be a positive number")}, + makeABTFunction("ln", optimizer::make<optimizer::Variable>(inputName))); + + _context->pushExpr(optimizer::make<optimizer::Let>( + std::move(inputName), _context->popABTExpr(), std::move(lnExpr))); + } void visit(const ExpressionLog* expr) final { unsupportedExpression(expr->getOpName()); } void visit(const ExpressionLog10* expr) final { + if (_context->hasAllAbtEligibleEntries(1)) { + return visitABT(expr); + } + auto frameId = _context->state.frameId(); auto binds = sbe::makeEs(_context->popExpr()); sbe::EVariable inputRef(frameId, 0); @@ -2457,6 +2676,29 @@ public: _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(log10Expr))); } + void visitABT(const ExpressionLog10* expr) { + auto inputName = makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + + auto log10Expr = buildABTMultiBranchConditional( + ABTCaseValuePair{generateABTNullOrMissing(inputName), optimizer::Constant::null()}, + ABTCaseValuePair{ + generateABTNonNumericCheck(inputName), + makeABTFail(ErrorCodes::Error{7157707}, "$log10 only supports numeric types")}, + // Note: In MQL, $log10 on a NumberDecimal NaN historically evaluates to a NumberDouble + // NaN. + ABTCaseValuePair{generateABTNaNCheck(inputName), + makeABTFunction("convert", + optimizer::make<optimizer::Variable>(inputName), + optimizer::Constant::int32(static_cast<int32_t>( + sbe::value::TypeTags::NumberDouble)))}, + ABTCaseValuePair{generateABTNonPositiveCheck(inputName), + makeABTFail(ErrorCodes::Error{7157708}, + "$log10's argument must be a positive number")}, + makeABTFunction("log10", optimizer::make<optimizer::Variable>(inputName))); + + _context->pushExpr(optimizer::make<optimizer::Let>( + std::move(inputName), _context->popABTExpr(), std::move(log10Expr))); + } void visit(const ExpressionInternalFLEBetween* expr) final { unsupportedExpression("$_internalFleBetween"); } @@ -2470,6 +2712,10 @@ public: unsupportedExpression("$meta"); } void visit(const ExpressionMod* expr) final { + if (_context->hasAllAbtEligibleEntries(2)) { + return visitABT(expr); + } + auto frameId = _context->state.frameId(); auto rhs = _context->popExpr(); auto lhs = _context->popExpr(); @@ -2514,10 +2760,63 @@ public: _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(modExpr))); } + void visitABT(const ExpressionMod* expr) { + auto rhs = _context->popABTExpr(); + auto lhs = _context->popABTExpr(); + auto lhsName = makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + auto rhsName = makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + + // If the rhs is a small integral double, convert it to int32 to match $mod MQL semantics. + auto numericConvert32 = makeABTFunction( + "convert", + optimizer::make<optimizer::Variable>(rhsName), + optimizer::Constant::int32(static_cast<int32_t>(sbe::value::TypeTags::NumberInt32))); + auto rhsExpr = buildABTMultiBranchConditional( + ABTCaseValuePair{ + optimizer::make<optimizer::BinaryOp>( + optimizer::Operations::And, + makeABTFunction("typeMatch", + optimizer::make<optimizer::Variable>(rhsName), + optimizer::Constant::int32( + getBSONTypeMask(sbe::value::TypeTags::NumberDouble))), + makeNot(makeABTFunction("typeMatch", + optimizer::make<optimizer::Variable>(lhsName), + optimizer::Constant::int32(getBSONTypeMask( + sbe::value::TypeTags::NumberDouble))))), + optimizer::make<optimizer::BinaryOp>( + optimizer::Operations::FillEmpty, + std::move(numericConvert32), + optimizer::make<optimizer::Variable>(rhsName))}, + optimizer::make<optimizer::Variable>(rhsName)); + + auto modExpr = buildABTMultiBranchConditional( + ABTCaseValuePair{ + optimizer::make<optimizer::BinaryOp>(optimizer::Operations::Or, + generateABTNullOrMissing(lhsName), + generateABTNullOrMissing(rhsName)), + optimizer::Constant::null()}, + ABTCaseValuePair{ + optimizer::make<optimizer::BinaryOp>(optimizer::Operations::Or, + generateABTNonNumericCheck(lhsName), + generateABTNonNumericCheck(rhsName)), + makeABTFail(ErrorCodes::Error{7157718}, "$mod only supports numeric types")}, + makeABTFunction( + "mod", optimizer::make<optimizer::Variable>(lhsName), std::move(rhsExpr))); + + _context->pushExpr(optimizer::make<optimizer::Let>( + std::move(lhsName), + std::move(lhs), + optimizer::make<optimizer::Let>( + std::move(rhsName), std::move(rhs), std::move(modExpr)))); + } void visit(const ExpressionMultiply* expr) final { auto arity = expr->getChildren().size(); _context->ensureArity(arity); + if (_context->hasAllAbtEligibleEntries(arity)) { + return visitABT(expr); + } + // Return multiplicative identity if the $multiply expression has no operands. if (arity == 0) { _context->pushExpr(makeConstant(sbe::value::TypeTags::NumberInt32, 1)); @@ -2584,6 +2883,70 @@ public: _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(multiplyExpr))); } + void visitABT(const ExpressionMultiply* expr) { + auto arity = expr->getChildren().size(); + + // Return multiplicative identity if the $multiply expression has no operands. + if (arity == 0) { + _context->pushExpr(optimizer::Constant::int32(1)); + return; + } + + optimizer::ABTVector binds; + optimizer::ProjectionNameVector names; + optimizer::ABTVector checkExprsNull; + optimizer::ABTVector checkExprsNumber; + optimizer::ABTVector variables; + binds.reserve(arity); + names.reserve(arity); + variables.reserve(arity); + checkExprsNull.reserve(arity); + checkExprsNumber.reserve(arity); + for (size_t idx = 0; idx < arity; ++idx) { + binds.push_back(_context->popABTExpr()); + auto currentName = + makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + names.push_back(currentName); + + checkExprsNull.push_back(generateABTNullOrMissing(currentName)); + checkExprsNumber.push_back( + makeABTFunction("isNumber", optimizer::make<optimizer::Variable>(currentName))); + variables.push_back(optimizer::make<optimizer::Variable>(currentName)); + } + + // At this point 'binds' vector contains arguments of $multiply expression in the reversed + // order. We need to reverse it back to perform multiplication in the right order below. + // Multiplication in different order can lead to different result because of accumulated + // precision errors from floating point types. + std::reverse(std::begin(binds), std::end(binds)); + + auto checkNullAnyArgument = + makeBalancedBooleanOpTree(optimizer::Operations::Or, std::move(checkExprsNull)); + auto checkNumberAllArguments = + makeBalancedBooleanOpTree(optimizer::Operations::And, std::move(checkExprsNumber)); + auto multiplication = std::accumulate(names.begin() + 1, + names.end(), + optimizer::make<optimizer::Variable>(names.front()), + [](auto&& acc, auto&& ex) { + return optimizer::make<optimizer::BinaryOp>( + optimizer::Operations::Mult, + std::move(acc), + optimizer::make<optimizer::Variable>(ex)); + }); + + auto multiplyExpr = buildABTMultiBranchConditional( + ABTCaseValuePair{std::move(checkNullAnyArgument), optimizer::Constant::null()}, + ABTCaseValuePair{std::move(checkNumberAllArguments), std::move(multiplication)}, + makeABTFail(ErrorCodes::Error{7157721}, + "only numbers are allowed in an $multiply expression")); + + for (size_t i = 0; i < arity; ++i) { + multiplyExpr = optimizer::make<optimizer::Let>( + std::move(names[i]), std::move(binds[i]), std::move(multiplyExpr)); + } + + _context->pushExpr(std::move(multiplyExpr)); + } void visit(const ExpressionNot* expr) final { if (_context->hasAllAbtEligibleEntries(1)) { _context->pushExpr(makeNot( @@ -2647,6 +3010,10 @@ public: unsupportedExpression("$pow"); } void visit(const ExpressionRange* expr) final { + if (_context->hasAllAbtEligibleEntries(expr->getChildren().size())) { + return visitABT(expr); + } + auto outerFrameId = _context->state.frameId(); auto innerFrameId = _context->state.frameId(); @@ -2717,6 +3084,106 @@ public: _context->pushExpr(std::move(rangeExpr)); } + void visitABT(const ExpressionRange* expr) { + auto startName = makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + auto endName = makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + auto stepName = makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + + auto convertedStartName = + makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + auto convertedEndName = + makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + auto convertedStepName = + makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + + auto step = expr->getChildren().size() == 3 ? _context->popABTExpr() + : optimizer::Constant::int32(1); + auto end = _context->popABTExpr(); + auto start = _context->popABTExpr(); + + auto rangeExpr = optimizer::make<optimizer::Let>( + std::move(startName), + std::move(start), + optimizer::make<optimizer::Let>( + std::move(endName), + std::move(end), + optimizer::make<optimizer::Let>( + std::move(stepName), + std::move(step), + buildABTMultiBranchConditional( + ABTCaseValuePair{ + generateABTNonNumericCheck(startName), + makeABTFail(ErrorCodes::Error{7157711}, + "$range only supports numeric types for start")}, + ABTCaseValuePair{generateABTNonNumericCheck(endName), + makeABTFail(ErrorCodes::Error{7157712}, + "$range only supports numeric types for end")}, + ABTCaseValuePair{ + generateABTNonNumericCheck(stepName), + makeABTFail(ErrorCodes::Error{7157713}, + "$range only supports numeric types for step")}, + optimizer::make<optimizer::Let>( + std::move(convertedStartName), + makeABTFunction("convert", + optimizer::make<optimizer::Variable>(startName), + optimizer::Constant::int32(static_cast<int32_t>( + sbe::value::TypeTags::NumberInt32))), + optimizer::make<optimizer::Let>( + std::move(convertedEndName), + makeABTFunction("convert", + optimizer::make<optimizer::Variable>(endName), + optimizer::Constant::int32(static_cast<int32_t>( + sbe::value::TypeTags::NumberInt32))), + optimizer::make<optimizer::Let>( + std::move(convertedStepName), + makeABTFunction("convert", + optimizer::make<optimizer::Variable>(stepName), + optimizer::Constant::int32(static_cast<int32_t>( + sbe::value::TypeTags::NumberInt32))), + buildABTMultiBranchConditional( + ABTCaseValuePair{ + makeNot(makeABTFunction( + "exists", + optimizer::make<optimizer::Variable>( + convertedStartName))), + makeABTFail(ErrorCodes::Error{7157714}, + "$range start argument cannot be " + "represented as a 32-bit integer")}, + ABTCaseValuePair{ + makeNot(makeABTFunction( + "exists", + optimizer::make<optimizer::Variable>( + convertedEndName))), + makeABTFail(ErrorCodes::Error{7157715}, + "$range end argument cannot be represented " + "as a 32-bit integer")}, + ABTCaseValuePair{ + makeNot(makeABTFunction( + "exists", + optimizer::make<optimizer::Variable>( + convertedStepName))), + makeABTFail(ErrorCodes::Error{7157716}, + "$range step argument cannot be " + "represented as a 32-bit integer")}, + ABTCaseValuePair{ + optimizer::make<optimizer::BinaryOp>( + optimizer::Operations::Eq, + optimizer::make<optimizer::Variable>( + convertedStepName), + optimizer::Constant::int32(0)), + makeABTFail(ErrorCodes::Error{7157717}, + "$range requires a non-zero step value")}, + makeABTFunction( + "newArrayFromRange", + optimizer::make<optimizer::Variable>( + convertedStartName), + optimizer::make<optimizer::Variable>(convertedEndName), + optimizer::make<optimizer::Variable>( + convertedStepName)))))))))); + + _context->pushExpr(std::move(rangeExpr)); + } + void visit(const ExpressionReduce* expr) final { unsupportedExpression("$reduce"); } @@ -2952,16 +3419,14 @@ public: auto name = makeLocalVariableName(frameId, 0); auto var = optimizer::make<optimizer::Variable>(name); - auto argumentIsNotArray = - makeNot(optimizer::make<optimizer::FunctionCall>("isArray", optimizer::ABTVector{var})); + auto argumentIsNotArray = makeNot(makeABTFunction("isArray", var)); auto exprReverseArr = buildABTMultiBranchConditional( ABTCaseValuePair{generateABTNullOrMissing(name), optimizer::Constant::null()}, ABTCaseValuePair{ std::move(argumentIsNotArray), makeABTFail(ErrorCodes::Error{7158002}, "$reverseArray argument must be an array")}, - optimizer::make<optimizer::FunctionCall>("reverseArray", - optimizer::ABTVector{std::move(var)})); + makeABTFunction("reverseArray", std::move(var))); _context->pushExpr(optimizer::make<optimizer::Let>( std::move(name), std::move(arg), std::move(exprReverseArr))); @@ -3013,8 +3478,7 @@ public: auto collatorVar = collatorSlot.map( [&](auto slotId) { return _context->registerVariable(*collatorSlot); }); - auto argumentIsNotArray = - makeNot(optimizer::make<optimizer::FunctionCall>("isArray", optimizer::ABTVector{var})); + auto argumentIsNotArray = makeNot(makeABTFunction("isArray", var)); optimizer::ABTVector functionArgs{std::move(var), std::move(specConstant)}; if (collatorVar) { @@ -3115,11 +3579,15 @@ public: sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(totalSplitFunc))); } void visit(const ExpressionSqrt* expr) final { + if (_context->hasAllAbtEligibleEntries(1)) { + return visitABT(expr); + } + auto frameId = _context->state.frameId(); auto binds = sbe::makeEs(_context->popExpr()); sbe::EVariable inputRef(frameId, 0); - auto lnExpr = buildMultiBranchConditional( + auto sqrtExpr = buildMultiBranchConditional( CaseValuePair{generateNullOrMissing(inputRef), sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0)}, CaseValuePair{generateNonNumericCheck(inputRef), @@ -3132,7 +3600,23 @@ public: makeFunction("sqrt", inputRef.clone())); _context->pushExpr( - sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(lnExpr))); + sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(sqrtExpr))); + } + void visitABT(const ExpressionSqrt* expr) { + auto inputName = makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + + auto sqrtExpr = buildABTMultiBranchConditional( + ABTCaseValuePair{generateABTNullOrMissing(inputName), optimizer::Constant::null()}, + ABTCaseValuePair{ + generateABTNonNumericCheck(inputName), + makeABTFail(ErrorCodes::Error{7157709}, "$sqrt only supports numeric types")}, + ABTCaseValuePair{generateABTNegativeCheck(inputName), + makeABTFail(ErrorCodes::Error{7157710}, + "$sqrt's argument must be greater than or equal to 0")}, + makeABTFunction("sqrt", optimizer::make<optimizer::Variable>(inputName))); + + _context->pushExpr(optimizer::make<optimizer::Let>( + std::move(inputName), _context->popABTExpr(), std::move(sqrtExpr))); } void visit(const ExpressionStrcasecmp* expr) final { unsupportedExpression(expr->getOpName()); @@ -3156,6 +3640,10 @@ public: invariant(expr->getChildren().size() == 2); _context->ensureArity(2); + if (_context->hasAllAbtEligibleEntries(2)) { + return visitABT(expr); + } + auto rhs = _context->popExpr(); auto lhs = _context->popExpr(); @@ -3191,6 +3679,48 @@ public: _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(subtractExpr))); } + void visitABT(const ExpressionSubtract* expr) { + auto rhs = _context->popABTExpr(); + auto lhs = _context->popABTExpr(); + + auto lhsName = makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + auto rhsName = makeLocalVariableName(_context->state.frameIdGenerator->generate(), 0); + + auto checkNullArguments = + optimizer::make<optimizer::BinaryOp>(optimizer::Operations::Or, + generateABTNullOrMissing(lhsName), + generateABTNullOrMissing(rhsName)); + + auto checkArgumentTypes = makeNot(optimizer::make<optimizer::If>( + makeABTFunction("isNumber", optimizer::make<optimizer::Variable>(lhsName)), + makeABTFunction("isNumber", optimizer::make<optimizer::Variable>(rhsName)), + optimizer::make<optimizer::BinaryOp>( + optimizer::Operations::And, + makeABTFunction("isDate", optimizer::make<optimizer::Variable>(lhsName)), + optimizer::make<optimizer::BinaryOp>( + optimizer::Operations::Or, + makeABTFunction("isNumber", optimizer::make<optimizer::Variable>(rhsName)), + makeABTFunction("isDate", optimizer::make<optimizer::Variable>(rhsName)))))); + + auto subtractOp = + optimizer::make<optimizer::BinaryOp>(optimizer::Operations::Sub, + optimizer::make<optimizer::Variable>(lhsName), + optimizer::make<optimizer::Variable>(rhsName)); + auto subtractExpr = buildABTMultiBranchConditional( + ABTCaseValuePair{std::move(checkNullArguments), optimizer::Constant::null()}, + ABTCaseValuePair{ + std::move(checkArgumentTypes), + makeABTFail(ErrorCodes::Error{7157720}, + "Only numbers and dates are allowed in an $subtract expression. To " + "subtract a number from a date, the date must be the first argument.")}, + std::move(subtractOp)); + + _context->pushExpr(optimizer::make<optimizer::Let>( + std::move(lhsName), + std::move(lhs), + optimizer::make<optimizer::Let>( + std::move(rhsName), std::move(rhs), std::move(subtractExpr)))); + } void visit(const ExpressionSwitch* expr) final { visitConditionalExpression(expr); } @@ -4044,11 +4574,11 @@ private: // Add start index operand. if (startIndexName) { - auto numericConvert64 = optimizer::make<optimizer::FunctionCall>( - "convert", - optimizer::ABTVector{optimizer::make<optimizer::Variable>(*startIndexName), - optimizer::Constant::int32( - static_cast<int32_t>(sbe::value::TypeTags::NumberInt64))}); + auto numericConvert64 = + makeABTFunction("convert", + optimizer::make<optimizer::Variable>(*startIndexName), + optimizer::Constant::int32( + static_cast<int32_t>(sbe::value::TypeTags::NumberInt64))); auto checkValidStartIndex = buildABTMultiBranchConditional( ABTCaseValuePair{generateABTNullishOrNotRepresentableInt32Check(*startIndexName), makeABTFail(ErrorCodes::Error{7158003}, @@ -4065,11 +4595,11 @@ private: // Add end index operand. if (endIndexName) { - auto numericConvert64 = optimizer::make<optimizer::FunctionCall>( - "convert", - optimizer::ABTVector{optimizer::make<optimizer::Variable>(*endIndexName), - optimizer::Constant::int32( - static_cast<int32_t>(sbe::value::TypeTags::NumberInt64))}); + auto numericConvert64 = + makeABTFunction("convert", + optimizer::make<optimizer::Variable>(*endIndexName), + optimizer::Constant::int32( + static_cast<int32_t>(sbe::value::TypeTags::NumberInt64))); auto checkValidEndIndex = buildABTMultiBranchConditional( ABTCaseValuePair{generateABTNullishOrNotRepresentableInt32Check(*endIndexName), makeABTFail(ErrorCodes::Error{7158005}, diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.cpp b/src/mongo/db/query/sbe_stage_builder_helpers.cpp index 40fe43f0b6e..5c7c04b3918 100644 --- a/src/mongo/db/query/sbe_stage_builder_helpers.cpp +++ b/src/mongo/db/query/sbe_stage_builder_helpers.cpp @@ -1412,6 +1412,37 @@ optimizer::ABT generateABTNullOrMissing(optimizer::ProjectionName var) { getBSONTypeMask(BSONType::Undefined))})); } +optimizer::ABT generateABTNegativeCheck(optimizer::ProjectionName var) { + return optimizer::make<optimizer::BinaryOp>(optimizer::Operations::Lt, + optimizer::make<optimizer::Variable>(var), + optimizer::Constant::int32(0)); +} + +optimizer::ABT generateABTNonPositiveCheck(optimizer::ProjectionName var) { + return optimizer::make<optimizer::BinaryOp>(optimizer::Operations::Lte, + optimizer::make<optimizer::Variable>(var), + optimizer::Constant::int32(0)); +} + +optimizer::ABT generateABTNonNumericCheck(optimizer::ProjectionName var) { + return makeNot(optimizer::make<optimizer::FunctionCall>( + "isNumber", optimizer::ABTVector{optimizer::make<optimizer::Variable>(var)})); +} + +optimizer::ABT generateABTLongLongMinCheck(optimizer::ProjectionName var) { + return optimizer::make<optimizer::BinaryOp>( + optimizer::Operations::And, + optimizer::make<optimizer::FunctionCall>( + "typeMatch", + optimizer::ABTVector{ + optimizer::make<optimizer::Variable>(var), + optimizer::Constant::int32(getBSONTypeMask(BSONType::NumberLong))}), + optimizer::make<optimizer::BinaryOp>( + optimizer::Operations::Eq, + optimizer::make<optimizer::Variable>(var), + optimizer::Constant::int64(std::numeric_limits<int64_t>::min()))); +} + optimizer::ABT generateABTNonArrayCheck(optimizer::ProjectionName var) { return makeNot(makeABTFunction("isArray", optimizer::make<optimizer::Variable>(var))); } @@ -1437,12 +1468,6 @@ optimizer::ABT generateABTNullishOrNotRepresentableInt32Check(optimizer::Project "exists", optimizer::ABTVector{std::move(numericConvert32)}))); } -optimizer::ABT generateABTNegativeCheck(optimizer::ProjectionName var) { - return optimizer::make<optimizer::BinaryOp>(optimizer::Operations::Lt, - optimizer::make<optimizer::Variable>(var), - optimizer::Constant::int32(0)); -} - optimizer::ABT generateABTNaNCheck(optimizer::ProjectionName var) { return makeABTFunction("isNaN", optimizer::make<optimizer::Variable>(var)); } diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.h b/src/mongo/db/query/sbe_stage_builder_helpers.h index 6b71b266c6b..77c71ead416 100644 --- a/src/mongo/db/query/sbe_stage_builder_helpers.h +++ b/src/mongo/db/query/sbe_stage_builder_helpers.h @@ -1248,11 +1248,14 @@ optimizer::ProjectionName makeVariableName(sbe::value::SlotId slotId); optimizer::ProjectionName makeLocalVariableName(sbe::FrameId frameId, sbe::value::SlotId slotId); optimizer::ABT generateABTNullOrMissing(optimizer::ProjectionName var); +optimizer::ABT generateABTNegativeCheck(optimizer::ProjectionName var); +optimizer::ABT generateABTNonPositiveCheck(optimizer::ProjectionName var); +optimizer::ABT generateABTNonNumericCheck(optimizer::ProjectionName var); +optimizer::ABT generateABTLongLongMinCheck(optimizer::ProjectionName var); optimizer::ABT generateABTNonArrayCheck(optimizer::ProjectionName var); optimizer::ABT generateABTNonObjectCheck(optimizer::ProjectionName var); optimizer::ABT generateABTNonStringCheck(optimizer::ProjectionName var); optimizer::ABT generateABTNullishOrNotRepresentableInt32Check(optimizer::ProjectionName var); -optimizer::ABT generateABTNegativeCheck(optimizer::ProjectionName var); /** * Generates an ABT that checks if the input expression is NaN _assuming that_ it has * already been verified to be numeric. |