summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMilena Ivanova <milena.ivanova@mongodb.com>2020-10-15 10:47:55 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2020-11-06 13:24:32 +0000
commita9386b0525171be6313d734a4b4241751a599341 (patch)
tree25db55820c3c1fdd92c1d736af6c2bce2cec0a2f /src
parentd08a42874819786ba4b5457a62fd01c2fbf0a59f (diff)
downloadmongo-a9386b0525171be6313d734a4b4241751a599341.tar.gz
SERVER-51270 Support $setDifference expression in SBE
Diffstat (limited to 'src')
-rw-r--r--src/mongo/db/exec/sbe/expressions/expression.cpp2
-rw-r--r--src/mongo/db/exec/sbe/expressions/sbe_set_expressions_test.cpp47
-rw-r--r--src/mongo/db/exec/sbe/vm/vm.cpp35
-rw-r--r--src/mongo/db/exec/sbe/vm/vm.h3
-rw-r--r--src/mongo/db/query/sbe_stage_builder_expression.cpp11
5 files changed, 93 insertions, 5 deletions
diff --git a/src/mongo/db/exec/sbe/expressions/expression.cpp b/src/mongo/db/exec/sbe/expressions/expression.cpp
index bccec36b846..3eee0f6a57b 100644
--- a/src/mongo/db/exec/sbe/expressions/expression.cpp
+++ b/src/mongo/db/exec/sbe/expressions/expression.cpp
@@ -408,6 +408,8 @@ static stdx::unordered_map<std::string, BuiltinFn> kBuiltinFunctions = {
{"setUnion", BuiltinFn{[](size_t n) { return n > 0; }, vm::Builtin::setUnion, false}},
{"setIntersection",
BuiltinFn{[](size_t n) { return n > 0; }, vm::Builtin::setIntersection, false}},
+ {"setDifference",
+ BuiltinFn{[](size_t n) { return n == 2; }, vm::Builtin::setDifference, false}},
};
/**
diff --git a/src/mongo/db/exec/sbe/expressions/sbe_set_expressions_test.cpp b/src/mongo/db/exec/sbe/expressions/sbe_set_expressions_test.cpp
index 69d2bf99e12..30eb61a7b2d 100644
--- a/src/mongo/db/exec/sbe/expressions/sbe_set_expressions_test.cpp
+++ b/src/mongo/db/exec/sbe/expressions/sbe_set_expressions_test.cpp
@@ -132,4 +132,51 @@ TEST_F(SBEBuiltinSetOpTest, ReturnsNothingSetIntersection) {
slotAccessor2.reset(value::TypeTags::NumberInt32, value::bitcastFrom<int32_t>(21));
runAndAssertNothing(compiledExpr.get());
}
+
+TEST_F(SBEBuiltinSetOpTest, ComputesSetDifference) {
+ value::OwnedValueAccessor slotAccessor1;
+ value::OwnedValueAccessor slotAccessor2;
+ auto arrSlot1 = bindAccessor(&slotAccessor1);
+ auto arrSlot2 = bindAccessor(&slotAccessor2);
+ auto setDiffExpr = sbe::makeE<sbe::EFunction>(
+ "setDifference", sbe::makeEs(makeE<EVariable>(arrSlot1), makeE<EVariable>(arrSlot2)));
+ auto compiledExpr = compileExpression(*setDiffExpr);
+
+ auto [arrTag1, arrVal1] = makeArray(BSON_ARRAY(1 << 2 << 3));
+ slotAccessor1.reset(arrTag1, arrVal1);
+ auto [arrTag2, arrVal2] = makeArray(BSON_ARRAY(2 << 5 << 7));
+ slotAccessor2.reset(arrTag2, arrVal2);
+ auto [resArrTag, resArrVal] = makeArraySet(BSON_ARRAY(1 << 3));
+ value::ValueGuard resGuard(resArrTag, resArrVal);
+ runAndAssertExpression(compiledExpr.get(), {resArrTag, resArrVal});
+
+ std::tie(arrTag1, arrVal1) = makeArray(BSON_ARRAY(1 << 2 << 3 << 1 << 2 << 3));
+ slotAccessor1.reset(arrTag1, arrVal1);
+ std::tie(arrTag2, arrVal2) = makeArray(BSON_ARRAY(2 << 5 << 7));
+ slotAccessor2.reset(arrTag2, arrVal2);
+ runAndAssertExpression(compiledExpr.get(), {resArrTag, resArrVal});
+
+ std::tie(arrTag1, arrVal1) = makeArray(BSON_ARRAY(1 << 2 << 3));
+ slotAccessor1.reset(arrTag1, arrVal1);
+ std::tie(arrTag2, arrVal2) = value::makeNewArray();
+ slotAccessor2.reset(arrTag2, arrVal2);
+ auto [resArrTag1, resArrVal1] = makeArraySet(BSON_ARRAY(1 << 2 << 3));
+ value::ValueGuard resGuard1(resArrTag1, resArrVal1);
+ runAndAssertExpression(compiledExpr.get(), {resArrTag1, resArrVal1});
+}
+
+TEST_F(SBEBuiltinSetOpTest, ReturnsNothingSetDifference) {
+ value::OwnedValueAccessor slotAccessor1;
+ value::OwnedValueAccessor slotAccessor2;
+ auto arrSlot1 = bindAccessor(&slotAccessor1);
+ auto arrSlot2 = bindAccessor(&slotAccessor2);
+ auto setDiffExpr = sbe::makeE<sbe::EFunction>(
+ "setDifference", sbe::makeEs(makeE<EVariable>(arrSlot1), makeE<EVariable>(arrSlot2)));
+ auto compiledExpr = compileExpression(*setDiffExpr);
+
+ auto [arrTag1, arrVal1] = makeArray(BSON_ARRAY(1 << 2));
+ slotAccessor1.reset(arrTag1, arrVal1);
+ slotAccessor2.reset(value::TypeTags::NumberInt32, value::bitcastFrom<int32_t>(125));
+ runAndAssertNothing(compiledExpr.get());
+}
} // namespace mongo::sbe
diff --git a/src/mongo/db/exec/sbe/vm/vm.cpp b/src/mongo/db/exec/sbe/vm/vm.cpp
index 9f1f343f830..cf067c690dc 100644
--- a/src/mongo/db/exec/sbe/vm/vm.cpp
+++ b/src/mongo/db/exec/sbe/vm/vm.cpp
@@ -1856,6 +1856,39 @@ std::tuple<bool, value::TypeTags, value::Value> ByteCode::builtinSetIntersection
return {true, resTag, resVal};
}
+std::tuple<bool, value::TypeTags, value::Value> ByteCode::builtinSetDifference(uint8_t arity) {
+ auto [lhsOwned, lhsTag, lhsVal] = getFromStack(0);
+ auto [rhsOwned, rhsTag, rhsVal] = getFromStack(1);
+
+ if (!value::isArray(lhsTag) || !value::isArray(rhsTag)) {
+ return {false, value::TypeTags::Nothing, 0};
+ }
+
+ auto [resTag, resVal] = value::makeNewArraySet();
+ value::ValueGuard resGuard{resTag, resVal};
+ auto resView = value::getArraySetView(resVal);
+
+ value::ValueSetType setValuesSecondArg;
+ auto rhsIter = value::ArrayEnumerator(rhsTag, rhsVal);
+ while (!rhsIter.atEnd()) {
+ auto [elTag, elVal] = rhsIter.getViewOfValue();
+ setValuesSecondArg.insert({elTag, elVal});
+ rhsIter.advance();
+ }
+
+ auto lhsIter = value::ArrayEnumerator(lhsTag, lhsVal);
+ while (!lhsIter.atEnd()) {
+ auto [elTag, elVal] = lhsIter.getViewOfValue();
+ if (setValuesSecondArg.count({elTag, elVal}) == 0) {
+ auto [copyTag, copyVal] = value::copyValue(elTag, elVal);
+ resView->push_back(copyTag, copyVal);
+ }
+ lhsIter.advance();
+ }
+
+ resGuard.reset();
+ return {true, resTag, resVal};
+}
std::tuple<bool, value::TypeTags, value::Value> ByteCode::dispatchBuiltin(Builtin f,
uint8_t arity) {
@@ -1958,6 +1991,8 @@ std::tuple<bool, value::TypeTags, value::Value> ByteCode::dispatchBuiltin(Builti
return builtinSetUnion(arity);
case Builtin::setIntersection:
return builtinSetIntersection(arity);
+ case Builtin::setDifference:
+ return builtinSetDifference(arity);
}
MONGO_UNREACHABLE;
diff --git a/src/mongo/db/exec/sbe/vm/vm.h b/src/mongo/db/exec/sbe/vm/vm.h
index 4068b92001d..d3d5d360748 100644
--- a/src/mongo/db/exec/sbe/vm/vm.h
+++ b/src/mongo/db/exec/sbe/vm/vm.h
@@ -221,6 +221,7 @@ enum class Builtin : uint8_t {
isTimezone,
setUnion,
setIntersection,
+ setDifference,
};
class CodeFragment {
@@ -541,6 +542,8 @@ private:
std::tuple<bool, value::TypeTags, value::Value> builtinIsTimezone(uint8_t arity);
std::tuple<bool, value::TypeTags, value::Value> builtinSetUnion(uint8_t arity);
std::tuple<bool, value::TypeTags, value::Value> builtinSetIntersection(uint8_t arity);
+ std::tuple<bool, value::TypeTags, value::Value> builtinSetDifference(uint8_t arity);
+
std::tuple<bool, value::TypeTags, value::Value> dispatchBuiltin(Builtin f, uint8_t arity);
std::tuple<bool, value::TypeTags, value::Value> getFromStack(size_t offset) {
diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp
index 450eadbb62c..0ca2094949f 100644
--- a/src/mongo/db/query/sbe_stage_builder_expression.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp
@@ -1938,20 +1938,21 @@ public:
unsupportedExpression(expr->getOpName());
}
void visit(ExpressionSetDifference* expr) final {
- unsupportedExpression(expr->getOpName());
+ invariant(expr->getChildren().size() == 2);
+ generateSetExpression(expr, "setDifference");
}
void visit(ExpressionSetEquals* expr) final {
unsupportedExpression(expr->getOpName());
}
void visit(ExpressionSetIntersection* expr) final {
- generateNarySetExpression(expr, "setIntersection");
+ generateSetExpression(expr, "setIntersection");
}
void visit(ExpressionSetIsSubset* expr) final {
unsupportedExpression(expr->getOpName());
}
void visit(ExpressionSetUnion* expr) final {
- generateNarySetExpression(expr, "setUnion");
+ generateSetExpression(expr, "setUnion");
}
void visit(ExpressionSize* expr) final {
@@ -2496,9 +2497,9 @@ private:
}
/**
- * Generic logic for building N-ary set expressions: setUnion, setIntersection, etc.
+ * Generic logic for building set expressions: setUnion, setIntersection, etc.
*/
- void generateNarySetExpression(Expression* expr, const std::string& setFunction) {
+ void generateSetExpression(Expression* expr, const std::string& setFunction) {
size_t arity = expr->getChildren().size();
_context->ensureArity(arity);
auto frameId = _context->frameIdGenerator->generate();