summaryrefslogtreecommitdiff
path: root/src/mongo/db
diff options
context:
space:
mode:
authorDrew Paroski <drew.paroski@mongodb.com>2022-01-13 22:07:55 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2022-02-02 21:26:56 +0000
commitc50744d17ab860d890140d22a1a2a07fd7efd3ea (patch)
treecffde5b9ab1b51c73fb3a32143ad2cfce72e1f98 /src/mongo/db
parent6b74d91648e6bacf446aeb84a56fda9bc60df90b (diff)
downloadmongo-c50744d17ab860d890140d22a1a2a07fd7efd3ea.tar.gz
SERVER-62502 Add ABT support for EFail and ENumericConvert
Diffstat (limited to 'src/mongo/db')
-rw-r--r--src/mongo/db/exec/sbe/abt/abt_lower.cpp50
-rw-r--r--src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp99
-rw-r--r--src/mongo/db/pipeline/abt/match_expression_visitor.cpp2
-rw-r--r--src/mongo/db/query/optimizer/syntax/expr.cpp8
-rw-r--r--src/mongo/db/query/optimizer/syntax/expr.h3
5 files changed, 154 insertions, 8 deletions
diff --git a/src/mongo/db/exec/sbe/abt/abt_lower.cpp b/src/mongo/db/exec/sbe/abt/abt_lower.cpp
index 9d806cd5afd..037344c2736 100644
--- a/src/mongo/db/exec/sbe/abt/abt_lower.cpp
+++ b/src/mongo/db/exec/sbe/abt/abt_lower.cpp
@@ -211,14 +211,50 @@ std::unique_ptr<sbe::EExpression> SBEExpressionLowering::transport(
const FunctionCall& fn, std::vector<std::unique_ptr<sbe::EExpression>> args) {
auto name = fn.name();
+ if (name == "fail") {
+ uassert(6250200, "Invalid number of arguments to fail()", fn.nodes().size() == 2);
+ const auto* codeConstPtr = fn.nodes().at(0).cast<Constant>();
+ const auto* messageConstPtr = fn.nodes().at(1).cast<Constant>();
+
+ uassert(6250201,
+ "First argument to fail() must be a 32-bit integer constant",
+ codeConstPtr != nullptr && codeConstPtr->isValueInt32());
+ uassert(6250202,
+ "Second argument to fail() must be a string constant",
+ messageConstPtr != nullptr && messageConstPtr->isString());
+
+ return sbe::makeE<sbe::EFail>(static_cast<ErrorCodes::Error>(codeConstPtr->getValueInt32()),
+ messageConstPtr->getString());
+ }
+
+ if (name == "convert") {
+ uassert(6250203, "Invalid number of arguments to convert()", fn.nodes().size() == 2);
+ const auto* constPtr = fn.nodes().at(1).cast<Constant>();
+
+ uassert(6250204,
+ "Second argument to convert() must be a 32-bit integer constant",
+ constPtr != nullptr && constPtr->isValueInt32());
+ int32_t constVal = constPtr->getValueInt32();
+
+ uassert(6250205,
+ "Second argument to convert() must be a numeric type tag",
+ constVal >= static_cast<int32_t>(std::numeric_limits<uint8_t>::min()) &&
+ constVal <= static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) &&
+ sbe::value::isNumber(static_cast<sbe::value::TypeTags>(constVal)));
+
+ return sbe::makeE<sbe::ENumericConvert>(std::move(args.at(0)),
+ static_cast<sbe::value::TypeTags>(constVal));
+ }
+
if (name == "typeMatch") {
- uassert(6624209, "Invalid number of typeMatch arguments", fn.nodes().size() == 2);
- if (const auto* constPtr = fn.nodes().at(1).cast<Constant>();
- constPtr != nullptr && constPtr->isValueInt64()) {
- return sbe::makeE<sbe::ETypeMatch>(std::move(args.at(0)), constPtr->getValueInt64());
- } else {
- uasserted(6624210, "Second argument of typeMatch must be an integer constant");
- }
+ uassert(6250206, "Invalid number of arguments to typeMatch()", fn.nodes().size() == 2);
+ const auto* constPtr = fn.nodes().at(1).cast<Constant>();
+
+ uassert(6250207,
+ "Second argument to typeMatch() must be a 32-bit integer constant",
+ constPtr != nullptr && constPtr->isValueInt32());
+
+ return sbe::makeE<sbe::ETypeMatch>(std::move(args.at(0)), constPtr->getValueInt32());
}
// TODO - this is an open question how to do the name mappings.
diff --git a/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp b/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp
index a8dbb516982..6fe5a0fa16c 100644
--- a/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp
+++ b/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp
@@ -238,6 +238,105 @@ TEST_F(ABTSBE, Lower7) {
ASSERT(result);
}
+TEST_F(ABTSBE, LowerFunctionCallFail) {
+ std::string errorMessage = "Error: Bad value 123456789!";
+
+ auto tree =
+ make<FunctionCall>("fail",
+ makeSeq(Constant::int32(static_cast<int32_t>(ErrorCodes::BadValue)),
+ Constant::str(errorMessage)));
+ auto env = VariableEnvironment::build(tree);
+ SlotVarMap map;
+
+ auto expr = SBEExpressionLowering{env, map}.optimize(tree);
+ ASSERT(expr);
+
+ auto compiledExpr = compileExpression(*expr);
+ Status status = Status::OK();
+
+ try {
+ auto [resultTag, resultVal] = runCompiledExpression(compiledExpr.get());
+ sbe::value::releaseValue(resultTag, resultVal);
+ } catch (const DBException& e) {
+ status = e.toStatus();
+ }
+
+ ASSERT(!status.isOK());
+ ASSERT_EQ(status.code(), ErrorCodes::BadValue);
+ ASSERT_EQ(status.reason(), errorMessage);
+}
+
+TEST_F(ABTSBE, LowerFunctionCallConvert) {
+ sbe::value::OwnedValueAccessor inputAccessor;
+ auto slotId = bindAccessor(&inputAccessor);
+ SlotVarMap map;
+ map["inputVar"] = slotId;
+
+ auto tree = make<FunctionCall>(
+ "convert",
+ makeSeq(make<Variable>("inputVar"),
+ Constant::int32(static_cast<int32_t>(sbe::value::TypeTags::NumberInt64))));
+ auto env = VariableEnvironment::build(tree);
+
+ auto expr = SBEExpressionLowering{env, map}.optimize(tree);
+ ASSERT(expr);
+
+ auto compiledExpr = compileExpression(*expr);
+
+ {
+ inputAccessor.reset(sbe::value::TypeTags::NumberDouble,
+ sbe::value::bitcastFrom<double>(42.0));
+ auto [resultTag, resultVal] = runCompiledExpression(compiledExpr.get());
+ sbe::value::ValueGuard guard(resultTag, resultVal);
+ ASSERT_EQ(resultTag, sbe::value::TypeTags::NumberInt64);
+ ASSERT_EQ(sbe::value::bitcastTo<int64_t>(resultVal), 42);
+ }
+
+ {
+ auto [tag, val] = sbe::value::makeCopyDecimal(Decimal128{-73});
+ inputAccessor.reset(tag, val);
+ auto [resultTag, resultVal] = runCompiledExpression(compiledExpr.get());
+ sbe::value::ValueGuard guard(resultTag, resultVal);
+ ASSERT_EQ(resultTag, sbe::value::TypeTags::NumberInt64);
+ ASSERT_EQ(sbe::value::bitcastTo<int64_t>(resultVal), -73);
+ }
+}
+
+TEST_F(ABTSBE, LowerFunctionCallTypeMatch) {
+ sbe::value::OwnedValueAccessor inputAccessor;
+ auto slotId = bindAccessor(&inputAccessor);
+ SlotVarMap map;
+ map["inputVar"] = slotId;
+
+ auto tree = make<FunctionCall>(
+ "typeMatch",
+ makeSeq(make<Variable>("inputVar"),
+ Constant::int32(getBSONTypeMask(sbe::value::TypeTags::NumberInt32) |
+ getBSONTypeMask(sbe::value::TypeTags::NumberInt64) |
+ getBSONTypeMask(sbe::value::TypeTags::NumberDouble) |
+ getBSONTypeMask(sbe::value::TypeTags::NumberDecimal))));
+ auto env = VariableEnvironment::build(tree);
+
+ auto expr = SBEExpressionLowering{env, map}.optimize(tree);
+ ASSERT(expr);
+
+ auto compiledExpr = compileExpression(*expr);
+
+ {
+ inputAccessor.reset(sbe::value::TypeTags::NumberDouble,
+ sbe::value::bitcastFrom<double>(123.0));
+ auto result = runCompiledExpressionPredicate(compiledExpr.get());
+ ASSERT(result);
+ }
+
+ {
+ auto [tag, val] = sbe::value::makeNewString("123");
+ inputAccessor.reset(tag, val);
+ auto result = runCompiledExpressionPredicate(compiledExpr.get());
+ ASSERT(!result);
+ }
+}
+
TEST_F(NodeSBE, Lower1) {
PrefixId prefixId;
Metadata metadata{{}};
diff --git a/src/mongo/db/pipeline/abt/match_expression_visitor.cpp b/src/mongo/db/pipeline/abt/match_expression_visitor.cpp
index 445c2e719ad..6c33e2d66e3 100644
--- a/src/mongo/db/pipeline/abt/match_expression_visitor.cpp
+++ b/src/mongo/db/pipeline/abt/match_expression_visitor.cpp
@@ -293,7 +293,7 @@ public:
lambdaProjName,
make<FunctionCall>("typeMatch",
makeSeq(make<Variable>(lambdaProjName),
- Constant::int64(expr->typeSet().getBSONTypeMask())))));
+ Constant::int32(expr->typeSet().getBSONTypeMask())))));
if (!expr->path().empty()) {
result = generateFieldPath(FieldPath(expr->path().toString()), std::move(result));
diff --git a/src/mongo/db/query/optimizer/syntax/expr.cpp b/src/mongo/db/query/optimizer/syntax/expr.cpp
index 1dcd85de61a..ab79bd69115 100644
--- a/src/mongo/db/query/optimizer/syntax/expr.cpp
+++ b/src/mongo/db/query/optimizer/syntax/expr.cpp
@@ -107,6 +107,14 @@ bool Constant::operator==(const Constant& other) const {
return sbe::value::bitcastTo<int32_t>(compareVal) == 0;
}
+bool Constant::isString() const {
+ return sbe::value::isString(_tag);
+}
+
+StringData Constant::getString() const {
+ return getStringView(_tag, _val);
+}
+
bool Constant::isValueInt64() const {
return _tag == TypeTags::NumberInt64;
}
diff --git a/src/mongo/db/query/optimizer/syntax/expr.h b/src/mongo/db/query/optimizer/syntax/expr.h
index a795625d59c..636cf75fda7 100644
--- a/src/mongo/db/query/optimizer/syntax/expr.h
+++ b/src/mongo/db/query/optimizer/syntax/expr.h
@@ -69,6 +69,9 @@ public:
return std::pair{_tag, _val};
}
+ bool isString() const;
+ StringData getString() const;
+
bool isValueInt64() const;
int64_t getValueInt64() const;