summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/mongo/db/pipeline/expression.cpp77
-rw-r--r--src/mongo/db/pipeline/expression.h8
2 files changed, 69 insertions, 16 deletions
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp
index 3a4a3bf55dd..3ac33bb261e 100644
--- a/src/mongo/db/pipeline/expression.cpp
+++ b/src/mongo/db/pipeline/expression.cpp
@@ -4694,28 +4694,46 @@ void ExpressionSetEquals::validateArguments(const ExpressionVector& args) const
args.size() >= 2);
}
+namespace {
+bool setEqualsHelper(const ValueUnorderedSet& lhs,
+ const ValueUnorderedSet& rhs,
+ const ValueComparator& valueComparator) {
+ if (lhs.size() != rhs.size()) {
+ return false;
+ }
+ for (const auto& entry : lhs) {
+ if (!rhs.count(entry)) {
+ return false;
+ }
+ }
+ return true;
+}
+} // namespace
+
Value ExpressionSetEquals::evaluate(const Document& root, Variables* variables) const {
const size_t n = _children.size();
const auto& valueComparator = getExpressionContext()->getValueComparator();
- ValueSet lhs = valueComparator.makeOrderedValueSet();
- for (size_t i = 0; i < n; i++) {
- const Value nextEntry = _children[i]->evaluate(root, variables);
+ auto evaluateChild = [&](size_t index) {
+ const Value entry = _children[index]->evaluate(root, variables);
uassert(17044,
- str::stream() << "All operands of $setEquals must be arrays. One "
- << "argument is of type: " << typeName(nextEntry.getType()),
- nextEntry.isArray());
+ str::stream() << "All operands of $setEquals must be arrays. " << (index + 1)
+ << "-th argument is of type: " << typeName(entry.getType()),
+ entry.isArray());
+ ValueUnorderedSet entrySet = valueComparator.makeUnorderedValueSet();
+ entrySet.insert(entry.getArray().begin(), entry.getArray().end());
+ return entrySet;
+ };
- if (i == 0) {
- lhs.insert(nextEntry.getArray().begin(), nextEntry.getArray().end());
- } else {
- ValueSet rhs = valueComparator.makeOrderedValueSet();
- rhs.insert(nextEntry.getArray().begin(), nextEntry.getArray().end());
- if (lhs.size() != rhs.size()) {
- return Value(false);
- }
+ size_t lhsIndex = _cachedConstant ? _cachedConstant->first : 0;
+ // The $setEquals expression has at least two children, so accessing the first child without
+ // check is fine.
+ ValueUnorderedSet lhs = _cachedConstant ? _cachedConstant->second : evaluateChild(0);
- if (!std::equal(lhs.begin(), lhs.end(), rhs.begin(), valueComparator.getEqualTo())) {
+ for (size_t i = 0; i < n; i++) {
+ if (i != lhsIndex) {
+ ValueUnorderedSet rhs = evaluateChild(i);
+ if (!setEqualsHelper(lhs, rhs, valueComparator)) {
return Value(false);
}
}
@@ -4723,6 +4741,35 @@ Value ExpressionSetEquals::evaluate(const Document& root, Variables* variables)
return Value(true);
}
+/**
+ * If there's a constant set in the input, we can construct a hash set for the constant once during
+ * optimize() and compare other sets against it, which reduces the runtime to construct the constant
+ * sets over and over.
+ */
+intrusive_ptr<Expression> ExpressionSetEquals::optimize() {
+ const size_t n = _children.size();
+ const ValueComparator& valueComparator = getExpressionContext()->getValueComparator();
+
+ for (size_t i = 0; i < n; i++) {
+ _children[i] = _children[i]->optimize();
+ if (ExpressionConstant* ec = dynamic_cast<ExpressionConstant*>(_children[i].get())) {
+ const Value nextEntry = ec->getValue();
+ uassert(5887502,
+ str::stream() << "All operands of $setEquals must be arrays. " << (i + 1)
+ << "-th argument is of type: " << typeName(nextEntry.getType()),
+ nextEntry.isArray());
+
+ if (!_cachedConstant) {
+ _cachedConstant = std::make_pair(i, valueComparator.makeUnorderedValueSet());
+ _cachedConstant->second.insert(nextEntry.getArray().begin(),
+ nextEntry.getArray().end());
+ }
+ }
+ }
+
+ return this;
+}
+
REGISTER_STABLE_EXPRESSION(setEquals, ExpressionSetEquals::parse);
const char* ExpressionSetEquals::getOpName() const {
return "$setEquals";
diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h
index e2caabf6605..10ed0f23bb1 100644
--- a/src/mongo/db/pipeline/expression.h
+++ b/src/mongo/db/pipeline/expression.h
@@ -2754,7 +2754,8 @@ public:
ExpressionSetEquals(ExpressionContext* const expCtx, ExpressionVector&& children)
: ExpressionVariadic<ExpressionSetEquals>(expCtx, std::move(children)) {}
- Value evaluate(const Document& root, Variables* variables) const final;
+ boost::intrusive_ptr<Expression> optimize() override;
+ Value evaluate(const Document& root, Variables* variables) const override;
const char* getOpName() const final;
void validateArguments(const ExpressionVector& args) const final;
@@ -2765,6 +2766,11 @@ public:
void acceptVisitor(ExpressionConstVisitor* visitor) const final {
return visitor->visit(this);
}
+
+private:
+ // The first element in the pair represent the position on the constant in the '_children'
+ // array. The second element is the constant set.
+ boost::optional<std::pair<size_t, ValueUnorderedSet>> _cachedConstant;
};