summaryrefslogtreecommitdiff
path: root/src/mongo/db/matcher
diff options
context:
space:
mode:
authorAlexander Ignatyev <alexander.ignatyev@mongodb.com>2022-01-18 07:04:39 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2022-01-18 07:49:55 +0000
commit271b0ed95bb065a4c46a4da4e0ddd0dcf2799543 (patch)
treede6ca7fe7db6ab2769c4ff41343c4fb2b13525c2 /src/mongo/db/matcher
parent42da3b4e6a22aacc93c4cbb646c1cd332e6d4bcd (diff)
downloadmongo-271b0ed95bb065a4c46a4da4e0ddd0dcf2799543.tar.gz
SERVER-61420 Create MatchExpression visitor to set paramId on tree nodes
Diffstat (limited to 'src/mongo/db/matcher')
-rw-r--r--src/mongo/db/matcher/SConscript2
-rw-r--r--src/mongo/db/matcher/expression.cpp9
-rw-r--r--src/mongo/db/matcher/expression.h7
-rw-r--r--src/mongo/db/matcher/expression_array.h15
-rw-r--r--src/mongo/db/matcher/expression_leaf.cpp3
-rw-r--r--src/mongo/db/matcher/expression_leaf.h131
-rw-r--r--src/mongo/db/matcher/expression_parameterization.cpp159
-rw-r--r--src/mongo/db/matcher/expression_parameterization.h170
-rw-r--r--src/mongo/db/matcher/expression_parameterization_test.cpp341
-rw-r--r--src/mongo/db/matcher/expression_type.h13
-rw-r--r--src/mongo/db/matcher/expression_where.cpp3
-rw-r--r--src/mongo/db/matcher/expression_where.h10
12 files changed, 850 insertions, 13 deletions
diff --git a/src/mongo/db/matcher/SConscript b/src/mongo/db/matcher/SConscript
index d6ae5c77848..3ddecb1524d 100644
--- a/src/mongo/db/matcher/SConscript
+++ b/src/mongo/db/matcher/SConscript
@@ -29,6 +29,7 @@ env.Library(
'expression_geo.cpp',
'expression_internal_bucket_geo_within.cpp',
'expression_leaf.cpp',
+ 'expression_parameterization.cpp',
'expression_parser.cpp',
'expression_text_base.cpp',
'expression_text_noop.cpp',
@@ -114,6 +115,7 @@ env.CppUnitTest(
'expression_internal_expr_eq_test.cpp',
'expression_leaf_test.cpp',
'expression_optimize_test.cpp',
+ 'expression_parameterization_test.cpp',
'expression_parser_array_test.cpp',
'expression_parser_geo_test.cpp',
'expression_parser_leaf_test.cpp',
diff --git a/src/mongo/db/matcher/expression.cpp b/src/mongo/db/matcher/expression.cpp
index 5f2324eff3f..7386e240198 100644
--- a/src/mongo/db/matcher/expression.cpp
+++ b/src/mongo/db/matcher/expression.cpp
@@ -31,6 +31,7 @@
#include "mongo/bson/bsonmisc.h"
#include "mongo/bson/bsonobj.h"
+#include "mongo/db/matcher/expression_parameterization.h"
#include "mongo/db/matcher/schema/json_schema_parser.h"
namespace mongo {
@@ -113,6 +114,14 @@ void MatchExpression::sortTree(MatchExpression* tree) {
}
}
+// static
+void MatchExpression::parameterize(MatchExpression* tree) {
+ MatchExpressionParameterizationVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ MatchExpressionParameterizationWalker walker{&visitor};
+ tree_walker::walk<false, MatchExpression>(tree, &walker);
+}
+
std::string MatchExpression::toString() const {
return serialize().toString();
}
diff --git a/src/mongo/db/matcher/expression.h b/src/mongo/db/matcher/expression.h
index 0fbabc323eb..c87738faaf8 100644
--- a/src/mongo/db/matcher/expression.h
+++ b/src/mongo/db/matcher/expression.h
@@ -208,6 +208,7 @@ public:
using Iterator = MatchExpressionIterator<false>;
using ConstIterator = MatchExpressionIterator<true>;
+ using InputParamId = int64_t;
/**
* Tracks the information needed to generate a document validation error for a
@@ -321,6 +322,12 @@ public:
return tree;
}
+ /**
+ * Assigns an optional input parameter ID to each node which is eligible for
+ * auto-parameterization.
+ */
+ static void parameterize(MatchExpression* tree);
+
MatchExpression(MatchType type, clonable_ptr<ErrorAnnotation> annotation = nullptr);
virtual ~MatchExpression() {}
diff --git a/src/mongo/db/matcher/expression_array.h b/src/mongo/db/matcher/expression_array.h
index edcadfe2b5e..ec5e2d76f7f 100644
--- a/src/mongo/db/matcher/expression_array.h
+++ b/src/mongo/db/matcher/expression_array.h
@@ -188,12 +188,15 @@ public:
int size,
clonable_ptr<ErrorAnnotation> annotation = nullptr);
- virtual std::unique_ptr<MatchExpression> shallowClone() const {
+ std::unique_ptr<MatchExpression> shallowClone() const final {
std::unique_ptr<SizeMatchExpression> e =
std::make_unique<SizeMatchExpression>(path(), _size, _errorAnnotation);
if (getTag()) {
e->setTag(getTag()->clone());
}
+ if (getInputParamId()) {
+ e->setInputParamId(*getInputParamId());
+ }
return e;
}
@@ -229,11 +232,21 @@ public:
visitor->visit(this);
}
+ void setInputParamId(InputParamId paramId) {
+ _inputParamId = paramId;
+ }
+
+ boost::optional<InputParamId> getInputParamId() const {
+ return _inputParamId;
+ }
+
private:
virtual ExpressionOptimizerFunc getOptimizer() const final {
return [](std::unique_ptr<MatchExpression> expression) { return expression; };
}
int _size; // >= 0 real, < 0, nothing will match
+
+ boost::optional<InputParamId> _inputParamId;
};
} // namespace mongo
diff --git a/src/mongo/db/matcher/expression_leaf.cpp b/src/mongo/db/matcher/expression_leaf.cpp
index 9c97567aeca..31157666e92 100644
--- a/src/mongo/db/matcher/expression_leaf.cpp
+++ b/src/mongo/db/matcher/expression_leaf.cpp
@@ -442,6 +442,9 @@ std::unique_ptr<MatchExpression> InMatchExpression::shallowClone() const {
static_cast<RegexMatchExpression*>(regex->shallowClone().release()));
next->_regexes.push_back(std::move(clonedRegex));
}
+ if (getInputParamId()) {
+ next->setInputParamId(*getInputParamId());
+ }
return next;
}
diff --git a/src/mongo/db/matcher/expression_leaf.h b/src/mongo/db/matcher/expression_leaf.h
index e949f251c3c..fdcc14e10b0 100644
--- a/src/mongo/db/matcher/expression_leaf.h
+++ b/src/mongo/db/matcher/expression_leaf.h
@@ -183,6 +183,14 @@ public:
return _collator;
}
+ void setInputParamId(InputParamId paramId) {
+ _inputParamId = paramId;
+ }
+
+ boost::optional<InputParamId> getInputParamId() const {
+ return _inputParamId;
+ }
+
protected:
/**
* 'collator' must outlive the ComparisonMatchExpression and any clones made of it.
@@ -202,6 +210,8 @@ private:
ExpressionOptimizerFunc getOptimizer() const final {
return [](std::unique_ptr<MatchExpression> expression) { return expression; };
}
+
+ boost::optional<InputParamId> _inputParamId;
};
/**
@@ -255,13 +265,16 @@ public:
return kName;
}
- virtual std::unique_ptr<MatchExpression> shallowClone() const {
+ std::unique_ptr<MatchExpression> shallowClone() const final {
std::unique_ptr<ComparisonMatchExpression> e =
std::make_unique<EqualityMatchExpression>(path(), Value(getData()), _errorAnnotation);
if (getTag()) {
e->setTag(getTag()->clone());
}
e->setCollator(_collator);
+ if (getInputParamId()) {
+ e->setInputParamId(*getInputParamId());
+ }
return e;
}
@@ -291,13 +304,16 @@ public:
return kName;
}
- virtual std::unique_ptr<MatchExpression> shallowClone() const {
+ std::unique_ptr<MatchExpression> shallowClone() const final {
std::unique_ptr<ComparisonMatchExpression> e =
std::make_unique<LTEMatchExpression>(path(), _rhs, _errorAnnotation);
if (getTag()) {
e->setTag(getTag()->clone());
}
e->setCollator(_collator);
+ if (getInputParamId()) {
+ e->setInputParamId(*getInputParamId());
+ }
return e;
}
@@ -327,13 +343,16 @@ public:
return kName;
}
- virtual std::unique_ptr<MatchExpression> shallowClone() const {
+ std::unique_ptr<MatchExpression> shallowClone() const final {
std::unique_ptr<ComparisonMatchExpression> e =
std::make_unique<LTMatchExpression>(path(), _rhs, _errorAnnotation);
if (getTag()) {
e->setTag(getTag()->clone());
}
e->setCollator(_collator);
+ if (getInputParamId()) {
+ e->setInputParamId(*getInputParamId());
+ }
return e;
}
@@ -368,13 +387,16 @@ public:
return kName;
}
- virtual std::unique_ptr<MatchExpression> shallowClone() const {
+ std::unique_ptr<MatchExpression> shallowClone() const final {
std::unique_ptr<ComparisonMatchExpression> e =
std::make_unique<GTMatchExpression>(path(), _rhs, _errorAnnotation);
if (getTag()) {
e->setTag(getTag()->clone());
}
e->setCollator(_collator);
+ if (getInputParamId()) {
+ e->setInputParamId(*getInputParamId());
+ }
return e;
}
@@ -408,13 +430,16 @@ public:
return kName;
}
- virtual std::unique_ptr<MatchExpression> shallowClone() const {
+ std::unique_ptr<MatchExpression> shallowClone() const final {
std::unique_ptr<ComparisonMatchExpression> e =
std::make_unique<GTEMatchExpression>(path(), _rhs, _errorAnnotation);
if (getTag()) {
e->setTag(getTag()->clone());
}
e->setCollator(_collator);
+ if (getInputParamId()) {
+ e->setInputParamId(*getInputParamId());
+ }
return e;
}
@@ -449,12 +474,18 @@ public:
~RegexMatchExpression();
- virtual std::unique_ptr<MatchExpression> shallowClone() const {
+ std::unique_ptr<MatchExpression> shallowClone() const final {
std::unique_ptr<RegexMatchExpression> e =
std::make_unique<RegexMatchExpression>(path(), _regex, _flags, _errorAnnotation);
if (getTag()) {
e->setTag(getTag()->clone());
}
+ if (getSourceRegexInputParamId()) {
+ e->setSourceRegexInputParamId(*getSourceRegexInputParamId());
+ }
+ if (getCompiledRegexInputParamId()) {
+ e->setCompiledRegexInputParamId(*getCompiledRegexInputParamId());
+ }
return e;
}
@@ -485,6 +516,22 @@ public:
visitor->visit(this);
}
+ void setSourceRegexInputParamId(InputParamId paramId) {
+ _sourceRegexInputParamId = paramId;
+ }
+
+ void setCompiledRegexInputParamId(InputParamId paramId) {
+ _compiledRegexInputParamId = paramId;
+ }
+
+ boost::optional<InputParamId> getSourceRegexInputParamId() const {
+ return _sourceRegexInputParamId;
+ }
+
+ boost::optional<InputParamId> getCompiledRegexInputParamId() const {
+ return _compiledRegexInputParamId;
+ }
+
private:
ExpressionOptimizerFunc getOptimizer() const final {
return [](std::unique_ptr<MatchExpression> expression) { return expression; };
@@ -495,6 +542,9 @@ private:
std::string _regex;
std::string _flags;
std::unique_ptr<pcrecpp::RE> _re;
+
+ boost::optional<InputParamId> _sourceRegexInputParamId;
+ boost::optional<InputParamId> _compiledRegexInputParamId;
};
class ModMatchExpression : public LeafMatchExpression {
@@ -504,12 +554,18 @@ public:
long long remainder,
clonable_ptr<ErrorAnnotation> annotation = nullptr);
- virtual std::unique_ptr<MatchExpression> shallowClone() const {
+ std::unique_ptr<MatchExpression> shallowClone() const final {
std::unique_ptr<ModMatchExpression> m =
std::make_unique<ModMatchExpression>(path(), _divisor, _remainder, _errorAnnotation);
if (getTag()) {
m->setTag(getTag()->clone());
}
+ if (getDivisorInputParamId()) {
+ m->setDivisorInputParamId(*getDivisorInputParamId());
+ }
+ if (getRemainderInputParamId()) {
+ m->setRemainderInputParamId(*getRemainderInputParamId());
+ }
return m;
}
@@ -536,6 +592,22 @@ public:
visitor->visit(this);
}
+ void setDivisorInputParamId(InputParamId paramId) {
+ _divisorInputParamId = paramId;
+ }
+
+ void setRemainderInputParamId(InputParamId paramId) {
+ _remainderInputParamId = paramId;
+ }
+
+ boost::optional<InputParamId> getDivisorInputParamId() const {
+ return _divisorInputParamId;
+ }
+
+ boost::optional<InputParamId> getRemainderInputParamId() const {
+ return _remainderInputParamId;
+ }
+
private:
ExpressionOptimizerFunc getOptimizer() const final {
return [](std::unique_ptr<MatchExpression> expression) { return expression; };
@@ -543,6 +615,9 @@ private:
long long _divisor;
long long _remainder;
+
+ boost::optional<InputParamId> _divisorInputParamId;
+ boost::optional<InputParamId> _remainderInputParamId;
};
class ExistsMatchExpression : public LeafMatchExpression {
@@ -588,7 +663,7 @@ class InMatchExpression : public LeafMatchExpression {
public:
explicit InMatchExpression(StringData path, clonable_ptr<ErrorAnnotation> annotation = nullptr);
- virtual std::unique_ptr<MatchExpression> shallowClone() const;
+ std::unique_ptr<MatchExpression> shallowClone() const final;
bool matchesSingleElement(const BSONElement&, MatchDetails* details = nullptr) const final;
@@ -639,6 +714,14 @@ public:
visitor->visit(this);
}
+ void setInputParamId(InputParamId paramId) {
+ _inputParamId = paramId;
+ }
+
+ boost::optional<InputParamId> getInputParamId() const {
+ return _inputParamId;
+ }
+
private:
ExpressionOptimizerFunc getOptimizer() const final;
@@ -675,6 +758,8 @@ private:
// When this $in is generated internally, e.g. via a rewrite, this is where we store the
// data of the corresponding equality elements.
BSONObj _equalityStorage;
+
+ boost::optional<InputParamId> _inputParamId;
};
/**
@@ -723,6 +808,14 @@ public:
std::string name() const;
+ void setInputParamId(InputParamId paramId) {
+ _inputParamId = paramId;
+ }
+
+ boost::optional<InputParamId> getInputParamId() const {
+ return _inputParamId;
+ }
+
private:
ExpressionOptimizerFunc getOptimizer() const final {
return [](std::unique_ptr<MatchExpression> expression) { return expression; };
@@ -754,6 +847,8 @@ private:
// Used to perform bit tests against numbers using a single bitwise operation.
uint64_t _bitMask = 0;
+
+ boost::optional<InputParamId> _inputParamId;
};
class BitsAllSetMatchExpression : public BitTestMatchExpression {
@@ -775,13 +870,16 @@ public:
: BitTestMatchExpression(
BITS_ALL_SET, path, bitMaskBinary, bitMaskLen, std::move(annotation)) {}
- virtual std::unique_ptr<MatchExpression> shallowClone() const {
+ std::unique_ptr<MatchExpression> shallowClone() const final {
std::unique_ptr<BitTestMatchExpression> bitTestMatchExpression =
std::make_unique<BitsAllSetMatchExpression>(
path(), getBitPositions(), _errorAnnotation);
if (getTag()) {
bitTestMatchExpression->setTag(getTag()->clone());
}
+ if (getInputParamId()) {
+ bitTestMatchExpression->setInputParamId(*getInputParamId());
+ }
return bitTestMatchExpression;
}
@@ -813,13 +911,16 @@ public:
: BitTestMatchExpression(
BITS_ALL_CLEAR, path, bitMaskBinary, bitMaskLen, std::move(annotation)) {}
- virtual std::unique_ptr<MatchExpression> shallowClone() const {
+ std::unique_ptr<MatchExpression> shallowClone() const final {
std::unique_ptr<BitTestMatchExpression> bitTestMatchExpression =
std::make_unique<BitsAllClearMatchExpression>(
path(), getBitPositions(), _errorAnnotation);
if (getTag()) {
bitTestMatchExpression->setTag(getTag()->clone());
}
+ if (getInputParamId()) {
+ bitTestMatchExpression->setInputParamId(*getInputParamId());
+ }
return bitTestMatchExpression;
}
@@ -851,13 +952,16 @@ public:
: BitTestMatchExpression(
BITS_ANY_SET, path, bitMaskBinary, bitMaskLen, std::move(annotation)) {}
- virtual std::unique_ptr<MatchExpression> shallowClone() const {
+ std::unique_ptr<MatchExpression> shallowClone() const final {
std::unique_ptr<BitTestMatchExpression> bitTestMatchExpression =
std::make_unique<BitsAnySetMatchExpression>(
path(), getBitPositions(), _errorAnnotation);
if (getTag()) {
bitTestMatchExpression->setTag(getTag()->clone());
}
+ if (getInputParamId()) {
+ bitTestMatchExpression->setInputParamId(*getInputParamId());
+ }
return bitTestMatchExpression;
}
@@ -889,13 +993,16 @@ public:
: BitTestMatchExpression(
BITS_ANY_CLEAR, path, bitMaskBinary, bitMaskLen, std::move(annotation)) {}
- virtual std::unique_ptr<MatchExpression> shallowClone() const {
+ std::unique_ptr<MatchExpression> shallowClone() const final {
std::unique_ptr<BitTestMatchExpression> bitTestMatchExpression =
std::make_unique<BitsAnyClearMatchExpression>(
path(), getBitPositions(), _errorAnnotation);
if (getTag()) {
bitTestMatchExpression->setTag(getTag()->clone());
}
+ if (getInputParamId()) {
+ bitTestMatchExpression->setInputParamId(*getInputParamId());
+ }
return bitTestMatchExpression;
}
diff --git a/src/mongo/db/matcher/expression_parameterization.cpp b/src/mongo/db/matcher/expression_parameterization.cpp
new file mode 100644
index 00000000000..de761274c8c
--- /dev/null
+++ b/src/mongo/db/matcher/expression_parameterization.cpp
@@ -0,0 +1,159 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/matcher/expression_parameterization.h"
+
+#include <cmath>
+
+namespace mongo {
+void MatchExpressionParameterizationVisitor::visit(BitsAllClearMatchExpression* expr) {
+ expr->setInputParamId(_context->nextInputParamId());
+}
+
+void MatchExpressionParameterizationVisitor::visit(BitsAllSetMatchExpression* expr) {
+ expr->setInputParamId(_context->nextInputParamId());
+}
+
+void MatchExpressionParameterizationVisitor::visit(BitsAnyClearMatchExpression* expr) {
+ expr->setInputParamId(_context->nextInputParamId());
+}
+
+void MatchExpressionParameterizationVisitor::visit(BitsAnySetMatchExpression* expr) {
+ expr->setInputParamId(_context->nextInputParamId());
+}
+
+void MatchExpressionParameterizationVisitor::visit(EqualityMatchExpression* expr) {
+ visitComparisonMatchExpression(expr);
+}
+
+void MatchExpressionParameterizationVisitor::visit(GTEMatchExpression* expr) {
+ visitComparisonMatchExpression(expr);
+}
+
+void MatchExpressionParameterizationVisitor::visit(GTMatchExpression* expr) {
+ visitComparisonMatchExpression(expr);
+}
+
+void MatchExpressionParameterizationVisitor::visit(LTEMatchExpression* expr) {
+ visitComparisonMatchExpression(expr);
+}
+
+void MatchExpressionParameterizationVisitor::visit(LTMatchExpression* expr) {
+ visitComparisonMatchExpression(expr);
+}
+
+void MatchExpressionParameterizationVisitor::visit(ModMatchExpression* expr) {
+ expr->setDivisorInputParamId(_context->nextInputParamId());
+ expr->setRemainderInputParamId(_context->nextInputParamId());
+}
+
+void MatchExpressionParameterizationVisitor::visit(RegexMatchExpression* expr) {
+ expr->setSourceRegexInputParamId(_context->nextInputParamId());
+ expr->setCompiledRegexInputParamId(_context->nextInputParamId());
+}
+
+void MatchExpressionParameterizationVisitor::visit(SizeMatchExpression* expr) {
+ expr->setInputParamId(_context->nextInputParamId());
+}
+
+void MatchExpressionParameterizationVisitor::visit(WhereMatchExpression* expr) {
+ expr->setInputParamId(_context->nextInputParamId());
+}
+
+void MatchExpressionParameterizationVisitor::visitComparisonMatchExpression(
+ ComparisonMatchExpressionBase* expr) {
+ auto type = expr->getData().type();
+ switch (type) {
+ case BSONType::MinKey:
+ case BSONType::EOO:
+ case BSONType::jstNULL:
+ case BSONType::Array:
+ case BSONType::DBRef:
+ case BSONType::MaxKey:
+ case BSONType::Undefined:
+ // ignore such values
+ break;
+
+ case BSONType::String:
+ case BSONType::Object:
+ case BSONType::BinData:
+ case BSONType::jstOID:
+ case BSONType::Bool:
+ case BSONType::RegEx:
+ case BSONType::Date:
+ case BSONType::Code:
+ case BSONType::Symbol:
+ case BSONType::CodeWScope:
+ case BSONType::NumberInt:
+ case BSONType::bsonTimestamp:
+ case BSONType::NumberLong:
+ expr->setInputParamId(_context->nextInputParamId());
+ break;
+ case BSONType::NumberDouble:
+ if (!std::isnan(expr->getData().numberDouble())) {
+ expr->setInputParamId(_context->nextInputParamId());
+ }
+ break;
+ case BSONType::NumberDecimal:
+ if (!expr->getData().numberDecimal().isNaN()) {
+ expr->setInputParamId(_context->nextInputParamId());
+ }
+ break;
+ }
+}
+
+void MatchExpressionParameterizationVisitor::visit(InMatchExpression* expr) {
+ // We don't set inputParamId if a InMatchExpression contains a regex.
+ if (!expr->getRegexes().empty()) {
+ return;
+ }
+
+ for (auto&& equality : expr->getEqualities()) {
+ switch (equality.type()) {
+ case BSONType::jstNULL:
+ case BSONType::Array:
+ // We don't set inputParamId if a InMatchExpression contains one of the values
+ // above.
+ return;
+ case BSONType::Undefined:
+ tasserted(6142000, "Unexpected type in $in expression");
+ default:
+ break;
+ };
+ }
+
+ expr->setInputParamId(_context->nextInputParamId());
+}
+
+void MatchExpressionParameterizationVisitor::visit(TypeMatchExpression* expr) {
+ if (!expr->typeSet().hasType(BSONType::Array)) {
+ expr->setInputParamId(_context->nextInputParamId());
+ }
+}
+} // namespace mongo
diff --git a/src/mongo/db/matcher/expression_parameterization.h b/src/mongo/db/matcher/expression_parameterization.h
new file mode 100644
index 00000000000..c88e3ab01ee
--- /dev/null
+++ b/src/mongo/db/matcher/expression_parameterization.h
@@ -0,0 +1,170 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <cstdint>
+
+#include "mongo/db/matcher/expression_always_boolean.h"
+#include "mongo/db/matcher/expression_array.h"
+#include "mongo/db/matcher/expression_leaf.h"
+#include "mongo/db/matcher/expression_type.h"
+#include "mongo/db/matcher/expression_visitor.h"
+#include "mongo/db/matcher/expression_where.h"
+
+namespace mongo {
+/**
+ * A context to track assigned input parameter IDs for auto-parameterization.
+ */
+class MatchExpressionParameterizationVisitorContext {
+public:
+ using InputParamId = MatchExpression::InputParamId;
+
+ virtual InputParamId nextInputParamId() {
+ return _inputParamIdCounter++;
+ }
+
+private:
+ InputParamId _inputParamIdCounter{0};
+};
+
+/**
+ * An implementation of a MatchExpression visitor which assigns an optional input parameter ID to
+ * each node which is eligible for auto-parameterization:
+ * - BitsAllClearMatchExpression
+ * - BitsAllSetMatchExpression
+ * - BitsAnyClearMatchExpression
+ * - BitsAnySetMatchExpression
+ * - Comparison expressions, unless compared against MinKey, MaxKey, null or NaN value or array
+ * - EqualityMatchExpression
+ * - GTEMatchExpression
+ * - GTMatchExpression
+ * - LTEMatchExpression
+ * - LTMatchExpression
+ * - InMatchExpression, unless it contains an array, null or regexp value.
+ * - ModMatchExpression (two parameter IDs for the divider and reminder)
+ * - RegexMatchExpression (two parameter IDs for the compiled regex and raw value)
+ * - SizeMatchExpression
+ * - TypeMatchExpression, unless type value is Array
+ * - WhereMatchExpression
+ */
+class MatchExpressionParameterizationVisitor final : public MatchExpressionMutableVisitor {
+public:
+ MatchExpressionParameterizationVisitor(MatchExpressionParameterizationVisitorContext* context)
+ : _context{context} {
+ invariant(_context);
+ }
+
+ void visit(AlwaysFalseMatchExpression* expr) final {}
+ void visit(AlwaysTrueMatchExpression* expr) final {}
+ void visit(AndMatchExpression* expr) final {}
+ void visit(BitsAllClearMatchExpression* expr) final;
+ void visit(BitsAllSetMatchExpression* expr) final;
+ void visit(BitsAnyClearMatchExpression* expr) final;
+ void visit(BitsAnySetMatchExpression* expr) final;
+ void visit(ElemMatchObjectMatchExpression* matchExpr) final {}
+ void visit(ElemMatchValueMatchExpression* matchExpr) final {}
+ void visit(EqualityMatchExpression* expr) final;
+ void visit(ExistsMatchExpression* expr) final {}
+ void visit(ExprMatchExpression* expr) final {}
+ void visit(GTEMatchExpression* expr) final;
+ void visit(GTMatchExpression* expr) final;
+ void visit(GeoMatchExpression* expr) final {}
+ void visit(GeoNearMatchExpression* expr) final {}
+ void visit(InMatchExpression* expr) final;
+ void visit(InternalBucketGeoWithinMatchExpression* expr) final {}
+ void visit(InternalExprEqMatchExpression* expr) final {}
+ void visit(InternalExprGTMatchExpression* expr) final {}
+ void visit(InternalExprGTEMatchExpression* expr) final {}
+ void visit(InternalExprLTMatchExpression* expr) final {}
+ void visit(InternalExprLTEMatchExpression* expr) final {}
+ void visit(InternalSchemaAllElemMatchFromIndexMatchExpression* expr) final {}
+ void visit(InternalSchemaAllowedPropertiesMatchExpression* expr) final {}
+ void visit(InternalSchemaBinDataEncryptedTypeExpression* expr) final {}
+ void visit(InternalSchemaBinDataSubTypeExpression* expr) final {}
+ void visit(InternalSchemaCondMatchExpression* expr) final {}
+ void visit(InternalSchemaEqMatchExpression* expr) final {}
+ void visit(InternalSchemaFmodMatchExpression* expr) final {}
+ void visit(InternalSchemaMatchArrayIndexMatchExpression* expr) final {}
+ void visit(InternalSchemaMaxItemsMatchExpression* expr) final {}
+ void visit(InternalSchemaMaxLengthMatchExpression* expr) final {}
+ void visit(InternalSchemaMaxPropertiesMatchExpression* expr) final {}
+ void visit(InternalSchemaMinItemsMatchExpression* expr) final {}
+ void visit(InternalSchemaMinLengthMatchExpression* expr) final {}
+ void visit(InternalSchemaMinPropertiesMatchExpression* expr) final {}
+ void visit(InternalSchemaObjectMatchExpression* expr) final {}
+ void visit(InternalSchemaRootDocEqMatchExpression* expr) final {}
+ void visit(InternalSchemaTypeExpression* expr) final {}
+ void visit(InternalSchemaUniqueItemsMatchExpression* expr) final {}
+ void visit(InternalSchemaXorMatchExpression* expr) final {}
+ void visit(LTEMatchExpression* expr) final;
+ void visit(LTMatchExpression* expr) final;
+ void visit(ModMatchExpression* expr) final;
+ void visit(NorMatchExpression* expr) final {}
+ void visit(NotMatchExpression* expr) final {}
+ void visit(OrMatchExpression* expr) final {}
+ void visit(RegexMatchExpression* expr) final;
+ void visit(SizeMatchExpression* expr) final;
+ void visit(TextMatchExpression* expr) final {}
+ void visit(TextNoOpMatchExpression* expr) final {}
+ void visit(TwoDPtInAnnulusExpression* expr) final {}
+ void visit(TypeMatchExpression* expr) final;
+ void visit(WhereMatchExpression* expr) final;
+ void visit(WhereNoOpMatchExpression* expr) final {}
+
+private:
+ void visitComparisonMatchExpression(ComparisonMatchExpressionBase* expr);
+
+ MatchExpressionParameterizationVisitorContext* _context;
+};
+
+/**
+ * A match expression tree walker compatible with tree_walker::walk() to be used with
+ * MatchExpressionParameterizationVisitor.
+ */
+class MatchExpressionParameterizationWalker {
+public:
+ MatchExpressionParameterizationWalker(MatchExpressionParameterizationVisitor* visitor)
+ : _visitor{visitor} {
+ invariant(_visitor);
+ }
+
+ void preVisit(MatchExpression* expr) {
+ expr->acceptVisitor(_visitor);
+ }
+
+ void postVisit(MatchExpression* expr) {}
+
+ void inVisit(long count, MatchExpression* expr) {}
+
+private:
+ MatchExpressionParameterizationVisitor* _visitor;
+};
+
+} // namespace mongo
diff --git a/src/mongo/db/matcher/expression_parameterization_test.cpp b/src/mongo/db/matcher/expression_parameterization_test.cpp
new file mode 100644
index 00000000000..37874a421d8
--- /dev/null
+++ b/src/mongo/db/matcher/expression_parameterization_test.cpp
@@ -0,0 +1,341 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/matcher/expression_parameterization.h"
+#include "mongo/db/operation_context.h"
+#include "mongo/db/pipeline/expression_context_for_test.h"
+#include "mongo/db/query/query_planner_params.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo {
+
+namespace {
+struct MatchExpressionParameterizationTestVisitorContext
+ : public MatchExpressionParameterizationVisitorContext {
+ InputParamId nextInputParamId() override {
+ auto paramId = MatchExpressionParameterizationVisitorContext::nextInputParamId();
+ inputParamIds.insert(paramId);
+ return paramId;
+ }
+
+ std::set<MatchExpression::InputParamId> inputParamIds{};
+};
+
+void walkExpression(MatchExpressionParameterizationVisitorContext* context,
+ MatchExpression* expression) {
+ MatchExpressionParameterizationVisitor visitor{context};
+ MatchExpressionParameterizationWalker walker{&visitor};
+ tree_walker::walk<false, MatchExpression>(expression, &walker);
+}
+} // namespace
+
+TEST(MatchExpressionParameterizationVisitor, AlwaysFalseMatchExpressionSetsNoParamIds) {
+ AlwaysFalseMatchExpression expr{};
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ expr.acceptVisitor(&visitor);
+ ASSERT_EQ(0, context.inputParamIds.size());
+}
+
+TEST(MatchExpressionParameterizationVisitor, AlwaysTrueMatchExpressionSetsNoParamIds) {
+ AlwaysTrueMatchExpression expr{};
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ expr.acceptVisitor(&visitor);
+ ASSERT_EQ(0, context.inputParamIds.size());
+}
+
+TEST(MatchExpressionParameterizationVisitor, BitsAllClearMatchExpressionSetsOneParamId) {
+ std::vector<uint32_t> bitPositions;
+ BitsAllClearMatchExpression expr{"a", bitPositions};
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ expr.acceptVisitor(&visitor);
+ ASSERT_EQ(1, context.inputParamIds.size());
+}
+
+TEST(MatchExpressionParameterizationVisitor, BitsAllSetMatchExpressionSetsOneParamId) {
+ std::vector<uint32_t> bitPositions;
+ BitsAllSetMatchExpression expr{"a", bitPositions};
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ expr.acceptVisitor(&visitor);
+ ASSERT_EQ(1, context.inputParamIds.size());
+}
+
+TEST(MatchExpressionParameterizationVisitor, BitsAnyClearMatchExpressionSetsOneParamId) {
+ std::vector<uint32_t> bitPositions{0, 1, 8};
+ BitsAnyClearMatchExpression expr{"a", bitPositions};
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ expr.acceptVisitor(&visitor);
+ ASSERT_EQ(1, context.inputParamIds.size());
+}
+
+TEST(MatchExpressionParameterizationVisitor, BitsAnySetMatchExpressionSetsOneParamId) {
+ std::vector<uint32_t> bitPositions{0, 1, 8};
+ BitsAnySetMatchExpression expr{"a", bitPositions};
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ expr.acceptVisitor(&visitor);
+ ASSERT_EQ(1, context.inputParamIds.size());
+}
+
+TEST(MatchExpressionParameterizationVisitor,
+ EqualityMatchExpressionWithScalarParameterSetsOneParamId) {
+ BSONObj query = BSON("a" << 5);
+ EqualityMatchExpression eq("a", query["a"]);
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ eq.acceptVisitor(&visitor);
+ ASSERT_EQ(1, context.inputParamIds.size());
+}
+
+TEST(MatchExpressionParameterizationVisitor, EqualityMatchExpressionWithNullSetsNoParamIds) {
+ BSONObj query = BSON("a" << BSONNULL);
+ EqualityMatchExpression eq{"a", query["a"]};
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ eq.acceptVisitor(&visitor);
+ ASSERT_EQ(0, context.inputParamIds.size());
+}
+
+TEST(MatchExpressionParameterizationVisitor, EqualityMatchExpressionWithArraySetsNoParamIds) {
+ BSONObj query = BSON("a" << BSON_ARRAY(1 << 2));
+ EqualityMatchExpression eq{"a", query["a"]};
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ eq.acceptVisitor(&visitor);
+ ASSERT_EQ(0, context.inputParamIds.size());
+}
+
+TEST(MatchExpressionParameterizationVisitor, EqualityMatchExpressionWithMinKeySetsNoParamIds) {
+ BSONObj query = BSON("a" << MINKEY);
+ EqualityMatchExpression eq{"a", query["a"]};
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ eq.acceptVisitor(&visitor);
+ ASSERT_EQ(0, context.inputParamIds.size());
+}
+
+TEST(MatchExpressionParameterizationVisitor, EqualityMatchExpressionWithMaxKeySetsNoParamIds) {
+ BSONObj query = BSON("a" << MAXKEY);
+ EqualityMatchExpression eq{"a", query["a"]};
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ eq.acceptVisitor(&visitor);
+ ASSERT_EQ(0, context.inputParamIds.size());
+}
+
+TEST(MatchExpressionParameterizationVisitor, EqualityMatchExpressionWithUndefinedThrows) {
+ BSONObj query = BSON("a" << BSONUndefined);
+ ASSERT_THROWS((EqualityMatchExpression{"a", query["a"]}), DBException);
+}
+
+TEST(MatchExpressionParameterizationVisitor, GTEMatchExpressionWithScalarParameterSetsOneParamId) {
+ BSONObj query = BSON("$gte" << 5);
+ GTEMatchExpression expr{"a", query["$gte"]};
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ expr.acceptVisitor(&visitor);
+ ASSERT_EQ(1, context.inputParamIds.size());
+}
+
+TEST(MatchExpressionParameterizationVisitor, GTEMatchExpressionWithUndefinedThrows) {
+ BSONObj query = BSON("a" << BSONUndefined);
+ ASSERT_THROWS((EqualityMatchExpression{"a", query["a"]}), DBException);
+}
+
+TEST(MatchExpressionParameterizationVisitor, GTMatchExpressionWithScalarParameterSetsOneParamId) {
+ BSONObj query = BSON("$gte" << 5);
+ GTMatchExpression expr{"a", query["$gte"]};
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ expr.acceptVisitor(&visitor);
+ ASSERT_EQ(1, context.inputParamIds.size());
+}
+
+TEST(MatchExpressionParameterizationVisitor, LTEMatchExpressionWithScalarParameterSetsOneParamId) {
+ BSONObj query = BSON("$lte" << 5);
+ LTEMatchExpression expr("a", query["$lte"]);
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ expr.acceptVisitor(&visitor);
+ ASSERT_EQ(1, context.inputParamIds.size());
+}
+
+TEST(MatchExpressionParameterizationVisitor, LTMatchExpressionWithScalarParameterSetsOneParamId) {
+ BSONObj query = BSON("$lt" << 5);
+ LTMatchExpression expr{"a", query["$lt"]};
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ expr.acceptVisitor(&visitor);
+ ASSERT_EQ(1, context.inputParamIds.size());
+}
+
+TEST(MatchExpressionParameterizationVisitor, ComparisonMatchExpressionsWithNaNSetsNoParamIds) {
+ std::vector<std::unique_ptr<MatchExpression>> expressions;
+
+ BSONObj doubleNaN = BSON("$lt" << std::numeric_limits<double>::quiet_NaN());
+ expressions.emplace_back(std::make_unique<LTMatchExpression>("a", doubleNaN["$lt"]));
+
+ BSONObj decimalNegativeNaN = BSON("$gt" << Decimal128::kNegativeNaN);
+ expressions.emplace_back(std::make_unique<GTMatchExpression>("b", decimalNegativeNaN["$gt"]));
+
+ BSONObj decimalPositiveNaN = BSON("c" << Decimal128::kPositiveNaN);
+ expressions.emplace_back(
+ std::make_unique<EqualityMatchExpression>("c", decimalPositiveNaN["c"]));
+
+ OrMatchExpression expr{std::move(expressions)};
+
+ MatchExpressionParameterizationTestVisitorContext context{};
+ walkExpression(&context, &expr);
+
+ ASSERT_EQ(0, context.inputParamIds.size());
+}
+
+TEST(MatchExpressionParameterizationVisitor, InMatchExpressionWithScalarsSetsOneParamId) {
+ BSONObj operand = BSON_ARRAY(1 << "r" << true << 1.1);
+ InMatchExpression expr{"a"};
+ std::vector<BSONElement> equalities{operand[0], operand[1], operand[2], operand[3]};
+ ASSERT_OK(expr.setEqualities(std::move(equalities)));
+
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ expr.acceptVisitor(&visitor);
+ ASSERT_EQ(1, context.inputParamIds.size());
+}
+
+TEST(MatchExpressionParameterizationVisitor, InMatchExpressionWithNullSetsNoParamIds) {
+ BSONObj operand = BSON_ARRAY(1 << "r" << true << BSONNULL);
+ InMatchExpression expr{"a"};
+ std::vector<BSONElement> equalities{operand[0], operand[1], operand[2], operand[3]};
+ ASSERT_OK(expr.setEqualities(std::move(equalities)));
+
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ expr.acceptVisitor(&visitor);
+ ASSERT_EQ(0, context.inputParamIds.size());
+}
+
+TEST(MatchExpressionParameterizationVisitor, InMatchExpressionWithRegexSetsNoParamIds) {
+ BSONObj query = BSON("a" << BSON("$in" << BSON_ARRAY(BSONRegEx("/^regex/i"))));
+
+ boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
+ StatusWithMatchExpression result = MatchExpressionParser::parse(query, expCtx);
+ ASSERT_TRUE(result.isOK());
+
+ MatchExpressionParameterizationTestVisitorContext context{};
+ walkExpression(&context, result.getValue().get());
+ ASSERT_EQ(0, context.nextInputParamId());
+}
+
+TEST(MatchExpressionParameterizationVisitor, ModMatchExpressionSetsTwoParamIds) {
+ ModMatchExpression expr{"a", 1, 2};
+
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ expr.acceptVisitor(&visitor);
+ ASSERT_EQ(2, context.inputParamIds.size());
+}
+
+TEST(MatchExpressionParameterizationVisitor, RegexMatchExpressionSetsTwoParamIds) {
+ RegexMatchExpression expr{"", "b", ""};
+
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ expr.acceptVisitor(&visitor);
+ ASSERT_EQ(2, context.inputParamIds.size());
+ ASSERT_EQ(2, context.nextInputParamId());
+}
+
+TEST(MatchExpressionParameterizationVisitor, SizeMatchExpressionSetsOneParamId) {
+ SizeMatchExpression expr{"a", 2};
+
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ expr.acceptVisitor(&visitor);
+ ASSERT_EQ(1, context.inputParamIds.size());
+ ASSERT_EQ(1, context.nextInputParamId());
+}
+
+TEST(MatchExpressionParameterizationVisitor, TypeMatchExpressionWithStringSetsOneParamId) {
+ TypeMatchExpression expr{"a", BSONType::String};
+
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ expr.acceptVisitor(&visitor);
+ ASSERT_EQ(1, context.inputParamIds.size());
+}
+
+TEST(MatchExpressionParameterizationVisitor, TypeMatchExpressionWithArraySetsNoParamIds) {
+ TypeMatchExpression expr{"a", BSONType::Array};
+
+ MatchExpressionParameterizationTestVisitorContext context{};
+ MatchExpressionParameterizationVisitor visitor{&context};
+ expr.acceptVisitor(&visitor);
+ ASSERT_EQ(0, context.inputParamIds.size());
+}
+
+TEST(MatchExpressionParameterizationVisitor, ExprMatchExpressionSetsNoParamsIds) {
+ BSONObj query = BSON("$expr" << BSON("$gte" << BSON_ARRAY("$a"
+ << "$b")));
+
+ boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
+ StatusWithMatchExpression result = MatchExpressionParser::parse(query, expCtx);
+ ASSERT_TRUE(result.isOK());
+
+ MatchExpressionParameterizationTestVisitorContext context{};
+ walkExpression(&context, result.getValue().get());
+ ASSERT_EQ(0, context.nextInputParamId());
+}
+
+TEST(MatchExpressionParameterizationVisitor,
+ AutoParametrizationWalkerSetsCorrectNumberOfParamsIds) {
+ BSONObj equalityExpr = BSON("x" << 1);
+ BSONObj gtExpr = BSON("y" << BSON("$gt" << 2));
+ BSONObj inExpr = BSON("$in" << BSON_ARRAY("a"
+ << "b"
+ << "c"));
+ BSONObj regexExpr = BSON("m" << BSONRegEx("/^regex/i"));
+ BSONObj sizeExpr = BSON("n" << BSON("$size" << 1));
+
+ BSONObj query = BSON("$or" << BSON_ARRAY(equalityExpr
+ << gtExpr << BSON("z" << inExpr)
+ << BSON("$and" << BSON_ARRAY(regexExpr << sizeExpr))));
+
+ boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
+ StatusWithMatchExpression result = MatchExpressionParser::parse(query, expCtx);
+ ASSERT_TRUE(result.isOK());
+
+ MatchExpressionParameterizationTestVisitorContext context{};
+ walkExpression(&context, result.getValue().get());
+ ASSERT_EQ(6, context.nextInputParamId());
+}
+} // namespace mongo
diff --git a/src/mongo/db/matcher/expression_type.h b/src/mongo/db/matcher/expression_type.h
index 437390c4139..889c9f8da66 100644
--- a/src/mongo/db/matcher/expression_type.h
+++ b/src/mongo/db/matcher/expression_type.h
@@ -75,6 +75,9 @@ public:
if (getTag()) {
expr->setTag(getTag()->clone());
}
+ if (getInputParamId()) {
+ expr->setInputParamId(*getInputParamId());
+ }
return expr;
}
@@ -122,6 +125,14 @@ public:
return _typeSet;
}
+ void setInputParamId(InputParamId paramId) {
+ _inputParamId = paramId;
+ }
+
+ boost::optional<InputParamId> getInputParamId() const {
+ return _inputParamId;
+ }
+
private:
ExpressionOptimizerFunc getOptimizer() const final {
return [](std::unique_ptr<MatchExpression> expression) { return expression; };
@@ -129,6 +140,8 @@ private:
// The set of matching types.
MatcherTypeSet _typeSet;
+
+ boost::optional<InputParamId> _inputParamId;
};
class TypeMatchExpression final : public TypeMatchExpressionBase<TypeMatchExpression> {
diff --git a/src/mongo/db/matcher/expression_where.cpp b/src/mongo/db/matcher/expression_where.cpp
index b94a4059144..4e5dcdc0060 100644
--- a/src/mongo/db/matcher/expression_where.cpp
+++ b/src/mongo/db/matcher/expression_where.cpp
@@ -69,6 +69,9 @@ unique_ptr<MatchExpression> WhereMatchExpression::shallowClone() const {
if (getTag()) {
e->setTag(getTag()->clone());
}
+ if (getInputParamId()) {
+ e->setInputParamId(*getInputParamId());
+ }
return e;
}
} // namespace mongo
diff --git a/src/mongo/db/matcher/expression_where.h b/src/mongo/db/matcher/expression_where.h
index 84142435b16..a8a8090616e 100644
--- a/src/mongo/db/matcher/expression_where.h
+++ b/src/mongo/db/matcher/expression_where.h
@@ -56,12 +56,22 @@ public:
return _jsFunction;
}
+ void setInputParamId(InputParamId paramId) {
+ _inputParamId = paramId;
+ }
+
+ boost::optional<InputParamId> getInputParamId() const {
+ return _inputParamId;
+ }
+
private:
std::string _dbName;
OperationContext* const _opCtx;
JsFunction _jsFunction;
+
+ boost::optional<InputParamId> _inputParamId;
};
} // namespace mongo