diff options
Diffstat (limited to 'src/mongo/db/pipeline')
-rw-r--r-- | src/mongo/db/pipeline/expression.cpp | 444 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression.h | 99 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression_test.cpp | 156 |
3 files changed, 436 insertions, 263 deletions
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index 364f8804cd9..ec0a0d8f435 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -1,3 +1,4 @@ + /** * Copyright (C) 2018-present MongoDB, Inc. * @@ -35,7 +36,6 @@ #include <algorithm> #include <boost/algorithm/string.hpp> #include <cstdio> -#include <pcre.h> #include <pcrecpp.h> #include <vector> @@ -5669,213 +5669,246 @@ Value ExpressionConvert::performConversion(BSONType targetType, Value inputValue BSONType inputType = inputValue.getType(); return table.findConversionFunc(inputType, targetType)(getExpressionContext(), inputValue); } - namespace { -class RegexMatchHandler { -public: - RegexMatchHandler(const Value& inputExpr) : _pcre(nullptr), _nullish(false) { - _validateInputAndExtractElements(inputExpr); - _compile(regex_util::flags2PcreOptions(_options, false).all_options()); - } +boost::optional<Value> extractValueFromConstantExpression( + const std::string& fieldName, + const std::vector<std::pair<std::string, boost::intrusive_ptr<Expression>>>& childExpressions) { + // Find the element with the fieldName. + auto expressionPairItr = std::find_if( + childExpressions.begin(), childExpressions.end(), [&](const auto& childExpression) { + return childExpression.first == fieldName; + }); - ~RegexMatchHandler() { - if (_pcre != nullptr) { - pcre_free(_pcre); - } + // If the field doesn't exists it is still eligible for optimization. + if (expressionPairItr == childExpressions.end()) { + return Value(BSONNULL); } - int execute(int startBytePos) { - int execResult = pcre_exec(_pcre, - 0, - _input.c_str(), - _input.size(), - startBytePos, - 0, // No need to overwrite the options set during pcre_compile. - &_capturesBuffer.front(), - _capturesBuffer.size()); - // The 'execResult' will be (_numCaptures + 1) if there is a match, -1 if there is no - // match, negative if there is an error during execution, and zero if _capturesBuffer's - // capacity is not sufficient to hold all the results. The latter scenario should never - // occur. - uassert( - 51156, - str::stream() << "Error occurred while executing the regular expression. Result code:" - << execResult, - execResult == -1 || execResult == (_numCaptures + 1)); - return execResult; + // If the field exists and not null/constant, we cannot optimize. + if (!ExpressionConstant::isNullOrConstant(expressionPairItr->second)) { + return boost::none; } + auto* expression = expressionPairItr->second.get(); + return dynamic_cast<ExpressionConstant*>(expression)->getValue(); +} - /** - * The function will match '_input' string based on the regex pattern present in '_pcre'. If - * there is a match, the function will return a 'Value' object encapsulating the matched string, - * the code point index of the matched string and a vector representing all the captured - * substrings. The function will also update the parameters 'startBytePos' and - * 'startCodePointPos' to the corresponding new indices. If there is no match, the function will - * return null 'Value' object. - */ - Value nextMatch(int* startBytePos, int* startCodePointPos) { - invariant(startBytePos != nullptr && startCodePointPos != nullptr); - - // Use input as StringData throughout the function to avoid copying the string on 'substr' - // calls. - StringData input = _input; - int execResult = execute(*startBytePos); - // No match. - if (execResult < 0) { - return Value(BSONNULL); - } +} // namespace - // The first and second entries of the '_capturesBuffer' will have the start and limit - // indices of the matched string, as byte offsets. '(limit - startIndex)' would be the - // length of the captured string. - const int matchStartByteIndex = _capturesBuffer[0]; - StringData matchedStr = - input.substr(matchStartByteIndex, _capturesBuffer[1] - matchStartByteIndex); - // We iterate through the input string's contents preceding the match index, in order to - // convert the byte offset to a code point offset. - for (int byteIx = *startBytePos; byteIx < matchStartByteIndex; ++(*startCodePointPos)) { - byteIx += getCodePointLength(input[byteIx]); - } - // Set the start index for match to the new one. - *startBytePos = matchStartByteIndex; - - std::vector<Value> captures; - captures.reserve(_numCaptures); - // The next '2 * numCaptures' entries (after the first two entries) of '_capturesBuffer' - // will hold the start index and limit pairs, for each of the capture groups. We skip the - // first two elements and start iteration from 3rd element so that we only construct the - // strings for capture groups. - for (int i = 0; i < _numCaptures; ++i) { - const int start = _capturesBuffer[2 * (i + 1)]; - const int limit = _capturesBuffer[2 * (i + 1) + 1]; - captures.push_back(Value(input.substr(start, limit - start))); +void RegexMatchHandler::optimize(boost::intrusive_ptr<Expression> expression) { + auto optimizedExpr = expression->optimize(); + + // If 'input', 'regex' and 'options' are null/constant then 'optimize()' will convert the + // object to an expression of type 'ExpressionConstant'. + if (auto* exprObj = dynamic_cast<ExpressionConstant*>(optimizedExpr.get())) { + _initialExecStateForConstantRegex = buildInitialState(exprObj->getValue()); + } else if (auto* exprObj = dynamic_cast<ExpressionObject*>(optimizedExpr.get())) { + // Extract the children and check for constant 'regex' and 'options'. + auto& children = exprObj->getChildExpressions(); + auto regex = extractValueFromConstantExpression("regex", children); + auto options = extractValueFromConstantExpression("options", children); + + // If both 'regex' and 'options' are null/constant, we can pre-compile the execution state. + if (regex && options) { + RegexExecutionState executionState; + _extractRegexAndOptions(&executionState, *regex, *options); + _compile(&executionState); + _initialExecStateForConstantRegex = std::move(executionState); } - - MutableDocument match; - match.addField("match", Value(matchedStr)); - match.addField("idx", Value(*startCodePointPos)); - match.addField("captures", Value(captures)); - return match.freezeToValue(); } +} - int numCaptures() { - return _numCaptures; - } +RegexMatchHandler::RegexExecutionState RegexMatchHandler::buildInitialState( + const Value& inputExpr) const { + uassert(51103, + str::stream() << "expression expects an object of named arguments, but found type " + << inputExpr.getType(), + inputExpr.getType() == BSONType::Object); + Value textInput = inputExpr.getDocument().getField("input"); + Value regexPattern = inputExpr.getDocument().getField("regex"); + Value regexOptions = inputExpr.getDocument().getField("options"); + + auto executionState = _initialExecStateForConstantRegex.value_or(RegexExecutionState()); - bool nullish() { - return _nullish; + // The 'input' parameter can be a variable and needs to be extracted from the expression + // document even when '_preExecutionState' is present. + _extractInputField(&executionState, textInput); + + // If we have a prebuilt execution state, then the 'regex' and 'options' fields are constant + // values, and we do not need to re-compile them. + if (!hasConstantRegex()) { + _extractRegexAndOptions(&executionState, regexPattern, regexOptions); + _compile(&executionState); } - StringData getInput() { - return _input; + return executionState; +} + +int RegexMatchHandler::execute(RegexExecutionState* regexState) const { + invariant(regexState); + invariant(!regexState->nullish()); + invariant(regexState->pcrePtr); + + int execResult = pcre_exec(regexState->pcrePtr.get(), + 0, + regexState->input->c_str(), + regexState->input->size(), + regexState->startBytePos, + 0, // No need to overwrite the options set during pcre_compile. + &(regexState->capturesBuffer.front()), + regexState->capturesBuffer.size()); + // The 'execResult' will be (numCaptures + 1) if there is a match, -1 if there is no match, + // negative (other than -1) if there is an error during execution, and zero if capturesBuffer's + // capacity is not sufficient to hold all the results. The latter scenario should never occur. + uassert(51156, + str::stream() << "Error occurred while executing the regular expression. Result code:" + << execResult, + execResult == -1 || execResult == (regexState->numCaptures + 1)); + return execResult; +} + +Value RegexMatchHandler::nextMatch(RegexExecutionState* regexState) const { + int execResult = execute(regexState); + + // No match. + if (execResult < 0) { + return Value(BSONNULL); } -private: - RegexMatchHandler(const RegexMatchHandler&) = delete; - - void _compile(const int pcreOptions) { - const char* compile_error; - int eoffset; - // The C++ interface pcreccp.h doesn't have a way to capture the matched string (or the - // index of the match). So we are using the C interface. First we compile all the regex - // options to generate pcre object, which will later be used to match against the input - // string. - _pcre = pcre_compile(_pattern.c_str(), pcreOptions, &compile_error, &eoffset, nullptr); - if (_pcre == nullptr) { - uasserted(51111, str::stream() << "Invalid Regex: " << compile_error); - } + // Use 'input' as StringData throughout the function to avoid copying the string on 'substr' + // calls. + StringData input = *(regexState->input); - // Calculate the number of capture groups present in '_pattern' and store in '_numCaptures'. - int pcre_retval = pcre_fullinfo(_pcre, NULL, PCRE_INFO_CAPTURECOUNT, &_numCaptures); - invariant(pcre_retval == 0); - - // The first two-thirds of the vector is used to pass back captured substrings' start and - // limit indexes. The remaining third of the vector is used as workspace by pcre_exec() - // while matching capturing subpatterns, and is not available for passing back information. - // pcre_compile will error if there are too many capture groups in the pattern. As long as - // this memory is allocated after compile, the amount of memory allocated will not be too - // high. - _capturesBuffer = std::vector<int>((1 + _numCaptures) * 3); - } - - void _validateInputAndExtractElements(const Value& inputExpr) { - uassert(51103, - str::stream() << "$regexFind expects an object of named arguments, but found type " - << inputExpr.getType(), - inputExpr.getType() == BSONType::Object); - Value textInput = inputExpr.getDocument().getField("input"); - Value regexPattern = inputExpr.getDocument().getField("regex"); - Value regexOptions = inputExpr.getDocument().getField("options"); - - uassert(51104, - "'input' field should be of type string", - textInput.nullish() || textInput.getType() == BSONType::String); - uassert(51105, - "'regex' field should be of type string or regex", - regexPattern.nullish() || regexPattern.getType() == BSONType::String || - regexPattern.getType() == BSONType::RegEx); - uassert(51106, - "'options' should be of type string", - regexOptions.nullish() || regexOptions.getType() == BSONType::String); - - // If either the text input or regex pattern is nullish, then we consider the operation as a - // whole nullish. - _nullish = textInput.nullish() || regexPattern.nullish(); - - if (textInput.getType() == BSONType::String) { - _input = textInput.getString(); - } + // The first and second entries of the 'capturesBuffer' will have the start and (end+1) indices + // of the matched string, as byte offsets. '(limit - startIndex)' would be the length of the + // captured string. + const int matchStartByteIndex = regexState->capturesBuffer[0]; + StringData matchedStr = + input.substr(matchStartByteIndex, regexState->capturesBuffer[1] - matchStartByteIndex); - // The 'regex' field can be a RegEx object and may have its own options... - if (regexPattern.getType() == BSONType::RegEx) { - StringData regexFlags = regexPattern.getRegexFlags(); - _pattern = regexPattern.getRegex(); - uassert(51107, - str::stream() - << "Found regex option(s) specified in both 'regex' and 'option' fields", - regexOptions.nullish() || regexFlags.empty()); - if (!regexFlags.empty()) { - _options = regexFlags.toString(); - } - } else if (regexPattern.getType() == BSONType::String) { - // ...or it can be a string field with options specified separately. - _pattern = regexPattern.getString(); - } - // If 'options' is non-null, we must extract and validate its contents even if - // 'regexPattern' is nullish. - if (!regexOptions.nullish()) { - _options = regexOptions.getString(); + // We iterate through the input string's contents preceding the match index, in order to convert + // the byte offset to a code point offset. + for (int byteIx = regexState->startBytePos; byteIx < matchStartByteIndex; + ++(regexState->startCodePointPos)) { + byteIx += getCodePointLength(input[byteIx]); + } + + // Set the start index for match to the new one. + regexState->startBytePos = matchStartByteIndex; + + std::vector<Value> captures; + captures.reserve(regexState->numCaptures); + + // The next '2 * numCaptures' entries (after the first two entries) of 'capturesBuffer' will + // hold the start index and limit pairs, for each of the capture groups. We skip the first two + // elements and start iteration from 3rd element so that we only construct the strings for + // capture groups. + for (int i = 0; i < regexState->numCaptures; ++i) { + const int start = regexState->capturesBuffer[2 * (i + 1)]; + const int limit = regexState->capturesBuffer[2 * (i + 1) + 1]; + captures.push_back(Value(input.substr(start, limit - start))); + } + + MutableDocument match; + match.addField("match", Value(matchedStr)); + match.addField("idx", Value(regexState->startCodePointPos)); + match.addField("captures", Value(captures)); + return match.freezeToValue(); +} + +void RegexMatchHandler::_compile(RegexExecutionState* executionState) const { + const auto pcreOptions = + regex_util::flags2PcreOptions(executionState->options, false).all_options(); + + if (!executionState->pattern) { + return; + } + const char* compile_error; + int eoffset; + + // The C++ interface pcreccp.h doesn't have a way to capture the matched string (or the index of + // the match). So we are using the C interface. First we compile all the regex options to + // generate pcre object, which will later be used to match against the input string. + executionState->pcrePtr = std::shared_ptr<pcre>( + pcre_compile( + executionState->pattern->c_str(), pcreOptions, &compile_error, &eoffset, nullptr), + pcre_free); + uassert(51111, str::stream() << "Invalid Regex: " << compile_error, executionState->pcrePtr); + + // Calculate the number of capture groups present in 'pattern' and store in 'numCaptures'. + const int pcre_retval = pcre_fullinfo( + executionState->pcrePtr.get(), NULL, PCRE_INFO_CAPTURECOUNT, &executionState->numCaptures); + invariant(pcre_retval == 0); + + // The first two-thirds of the vector is used to pass back captured substrings' start and + // (end+1) indexes. The remaining third of the vector is used as workspace by pcre_exec() while + // matching capturing subpatterns, and is not available for passing back information. + // pcre_compile will error if there are too many capture groups in the pattern. As long as this + // memory is allocated after compile, the amount of memory allocated will not be too high. + executionState->capturesBuffer.resize((1 + executionState->numCaptures) * 3); +} + +void RegexMatchHandler::_extractInputField(RegexExecutionState* executionState, + const Value& textInput) const { + uassert(51104, + "'input' field should be of type string", + textInput.nullish() || textInput.getType() == BSONType::String); + if (textInput.getType() == BSONType::String) { + executionState->input = textInput.getString(); + } +} + +void RegexMatchHandler::_extractRegexAndOptions(RegexExecutionState* executionState, + const Value& regexPattern, + const Value& regexOptions) const { + uassert(51105, + "'regex' field should be of type string or regex", + regexPattern.nullish() || regexPattern.getType() == BSONType::String || + regexPattern.getType() == BSONType::RegEx); + uassert(51106, + "'options' should be of type string", + regexOptions.nullish() || regexOptions.getType() == BSONType::String); + + // The 'regex' field can be a RegEx object and may have its own options... + if (regexPattern.getType() == BSONType::RegEx) { + StringData regexFlags = regexPattern.getRegexFlags(); + executionState->pattern = regexPattern.getRegex(); + uassert( + 51107, + str::stream() << "Found regex option(s) specified in both 'regex' and 'option' fields", + regexOptions.nullish() || regexFlags.empty()); + if (!regexFlags.empty()) { + executionState->options = regexFlags.toString(); } - uassert(51109, - "Regular expression cannot contain an embedded null byte", - _pattern.find('\0', 0) == string::npos); - uassert(51110, - "Regular expression options string cannot contain an embedded null byte", - _options.find('\0', 0) == string::npos); - } - - pcre* _pcre; - // Number of capture groups present in '_pattern'. - int _numCaptures; - // Holds the start and limit indices of match and captures for the current match. - std::vector<int> _capturesBuffer; - std::string _input; - std::string _pattern; - std::string _options; - bool _nullish; -}; + } else if (regexPattern.getType() == BSONType::String) { + // ...or it can be a string field with options specified separately. + executionState->pattern = regexPattern.getString(); + } -} // namespace + // If 'options' is non-null, we must extract and validate its contents even if 'regexPattern' is + // nullish. + if (!regexOptions.nullish()) { + executionState->options = regexOptions.getString(); + } + uassert(51109, + "Regular expression cannot contain an embedded null byte", + !executionState->pattern || executionState->pattern->find('\0', 0) == string::npos); + uassert(51110, + "Regular expression options string cannot contain an embedded null byte", + executionState->options.find('\0', 0) == string::npos); +} -Value ExpressionRegexFind::evaluate(const Document& root) const { +boost::intrusive_ptr<Expression> ExpressionRegexFind::optimize() { + _handler.optimize(vpOperand[0]); + return this; +} - RegexMatchHandler regex(vpOperand[0]->evaluate(root)); - if (regex.nullish()) { +Value ExpressionRegexFind::evaluate(const Document& root) const { + auto executionState = _handler.buildInitialState(vpOperand[0]->evaluate(root)); + if (executionState.nullish()) { return Value(BSONNULL); } - int startByteIndex = 0, startCodePointIndex = 0; - return regex.nextMatch(&startByteIndex, &startCodePointIndex); + return _handler.nextMatch(&executionState); } REGISTER_EXPRESSION(regexFind, ExpressionRegexFind::parse); @@ -5883,21 +5916,24 @@ const char* ExpressionRegexFind::getOpName() const { return "$regexFind"; } -Value ExpressionRegexFindAll::evaluate(const Document& root) const { +boost::intrusive_ptr<Expression> ExpressionRegexFindAll::optimize() { + _handler.optimize(vpOperand[0]); + return this; +} +Value ExpressionRegexFindAll::evaluate(const Document& root) const { std::vector<Value> output; - RegexMatchHandler regex(vpOperand[0]->evaluate(root)); - if (regex.nullish()) { + auto executionState = _handler.buildInitialState(vpOperand[0]->evaluate(root)); + if (executionState.nullish()) { return Value(output); } - int startByteIndex = 0, startCodePointIndex = 0; - StringData input = regex.getInput(); + StringData input = *(executionState.input); size_t totalDocSize = 0; // Using do...while loop because, when input is an empty string, we still want to see if there // is a match. do { - auto matchObj = regex.nextMatch(&startByteIndex, &startCodePointIndex); + auto matchObj = _handler.nextMatch(&executionState); if (matchObj.getType() == BSONType::jstNULL) { break; } @@ -5913,19 +5949,22 @@ Value ExpressionRegexFindAll::evaluate(const Document& root) const { // the character at startByteIndex matches the regex, we cannot return it since we are // already returing an empty string starting at this index. So we move on to the next // byte index. - startByteIndex += getCodePointLength(input[startByteIndex]); - ++startCodePointIndex; + executionState.startBytePos += getCodePointLength(input[executionState.startBytePos]); + ++executionState.startCodePointPos; continue; } - // We don't want any overlapping sub-strings. So we move 'startByteIndex' to point to the + + // We don't want any overlapping sub-strings. So we move 'startBytePos' to point to the // byte after 'matchStr'. We move the code point index also correspondingly. - startByteIndex += matchStr.size(); - for (size_t byteIx = 0; byteIx < matchStr.size(); ++startCodePointIndex) { + executionState.startBytePos += matchStr.size(); + for (size_t byteIx = 0; byteIx < matchStr.size(); ++executionState.startCodePointPos) { byteIx += getCodePointLength(matchStr[byteIx]); } - invariant(startByteIndex > 0 && startCodePointIndex > 0 && - startCodePointIndex <= startByteIndex); - } while (static_cast<size_t>(startByteIndex) < input.size()); + + invariant(executionState.startBytePos > 0); + invariant(executionState.startCodePointPos > 0); + invariant(executionState.startCodePointPos <= executionState.startBytePos); + } while (static_cast<size_t>(executionState.startBytePos) < input.size()); return Value(output); } @@ -5934,10 +5973,15 @@ const char* ExpressionRegexFindAll::getOpName() const { return "$regexFindAll"; } +boost::intrusive_ptr<Expression> ExpressionRegexMatch::optimize() { + _handler.optimize(vpOperand[0]); + return this; +} + Value ExpressionRegexMatch::evaluate(const Document& root) const { - RegexMatchHandler regex(vpOperand[0]->evaluate(root)); + auto executionState = _handler.buildInitialState(vpOperand[0]->evaluate(root)); // Return output of execute only if regex is not nullish. - return regex.nullish() ? Value(false) : Value(regex.execute(0) > 0); + return executionState.nullish() ? Value(false) : Value(_handler.execute(&executionState) > 0); } REGISTER_EXPRESSION(regexMatch, ExpressionRegexMatch::parse); diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h index 4969e2ea231..ddd77682d36 100644 --- a/src/mongo/db/pipeline/expression.h +++ b/src/mongo/db/pipeline/expression.h @@ -34,6 +34,7 @@ #include <algorithm> #include <boost/intrusive_ptr.hpp> #include <map> +#include <pcre.h> #include <string> #include <vector> @@ -2433,17 +2434,101 @@ private: boost::intrusive_ptr<Expression> _onNull; }; +class RegexMatchHandler { +public: + /** + * Object to hold data that is required by 'RegexMatchHandler' for calling 'execute()' or + * 'nextMatch()'. + */ + struct RegexExecutionState { + boost::optional<std::string> pattern; + std::string options; + int numCaptures = 0; + std::vector<int> capturesBuffer; + /** + * If there is a constant regex, the underlying object of 'pcre' will be owned by + * 'RegexMatchHandler', as part of '_preExecutionState'. If not, it will be owned by + * 'RegexExecutionState'. + */ + std::shared_ptr<pcre> pcrePtr; + boost::optional<std::string> input; + int startBytePos = 0; + int startCodePointPos = 0; + + /** + * If either the text input or regex pattern is nullish, then we consider the operation as a + * whole nullish. + */ + bool nullish() { + return !input || !pattern; + } + }; + + /** + * Checks if there is a match for the given input and pattern that are part of 'executionState'. + * The method will return a positive number if there is a match and '-1' if there is no match. + * Throws 'uassert()' for any errors. + */ + int execute(RegexExecutionState* executionState) const; + + /** + * Finds the next possible match for the given input and pattern that are part of + * 'executionState'. If there is a match, the function will return a 'Value' object + * encapsulating the matched string, the code point index of the matched string and a vector + * representing all the captured substrings. The function will also update the parameters + * 'startBytePos' and 'startCodePointPos' to the corresponding new indices. If there is no + * match, the function will return null 'Value' object. + */ + Value nextMatch(RegexExecutionState* executionState) const; + + /** + * Optimizes '$regex*' expressions. If the expression has a constant 'regex' and 'options' + * fields, then it can be optimized. Stores the optimized regex as part of '_constantRegex' so + * that it can be reused during expression evaluation. + */ + void optimize(boost::intrusive_ptr<Expression> expression); + + /** + * Validates the structure of input passed in 'inputExpr'. If valid, generates an initial + * execution state. This returned object can later be used for calling execute() or nextMatch(). + */ + RegexExecutionState buildInitialState(const Value& inputExpr) const; + bool hasConstantRegex() const { + return _initialExecStateForConstantRegex.has_value(); + } + +private: + void _extractInputField(RegexExecutionState* executionState, const Value& textInput) const; + void _extractRegexAndOptions(RegexExecutionState* executionState, + const Value& regexPattern, + const Value& regexOptions) const; + void _compile(RegexExecutionState* executionState) const; + /** + * This variable will be set when the $regex* expressions have constant values for their 'regex' + * and 'options' fields, allowing us to pre-compile the regex and re-use it across the + * Expression's lifetime. + */ + boost::optional<RegexExecutionState> _initialExecStateForConstantRegex; +}; + class ExpressionRegexFind final : public ExpressionFixedArity<ExpressionRegexFind, 1> { public: explicit ExpressionRegexFind(const boost::intrusive_ptr<ExpressionContext>& expCtx) : ExpressionFixedArity<ExpressionRegexFind, 1>(expCtx) {} Value evaluate(const Document& root) const final; + boost::intrusive_ptr<Expression> optimize() final; const char* getOpName() const final; void acceptVisitor(ExpressionVisitor* visitor) final { return visitor->visit(this); } + bool hasConstantRegex() const { + return _handler.hasConstantRegex(); + } + +private: + RegexMatchHandler _handler; }; class ExpressionRegexFindAll final : public ExpressionFixedArity<ExpressionRegexFindAll, 1> { @@ -2452,10 +2537,17 @@ public: : ExpressionFixedArity<ExpressionRegexFindAll, 1>(expCtx) {} Value evaluate(const Document& root) const final; + boost::intrusive_ptr<Expression> optimize() final; const char* getOpName() const final; void acceptVisitor(ExpressionVisitor* visitor) final { return visitor->visit(this); } + bool hasConstantRegex() const { + return _handler.hasConstantRegex(); + } + +private: + RegexMatchHandler _handler; }; class ExpressionRegexMatch final : public ExpressionFixedArity<ExpressionRegexMatch, 1> { @@ -2464,10 +2556,17 @@ public: : ExpressionFixedArity<ExpressionRegexMatch, 1>(expCtx) {} Value evaluate(const Document& root) const final; + boost::intrusive_ptr<Expression> optimize() final; const char* getOpName() const final; void acceptVisitor(ExpressionVisitor* visitor) final { return visitor->visit(this); } + bool hasConstantRegex() const { + return _handler.hasConstantRegex(); + } + +private: + RegexMatchHandler _handler; }; } diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp index 5028d870e24..c1ba68f4cd2 100644 --- a/src/mongo/db/pipeline/expression_test.cpp +++ b/src/mongo/db/pipeline/expression_test.cpp @@ -5956,24 +5956,93 @@ TEST(GetComputedPathsTest, ExpressionMapNotConsideredRenameWithDottedInputPath) namespace ExpressionRegexTest { -TEST(ExpressionRegexFindTest, BasicTest) { - Value input(fromjson("{input: 'asdf', regex: '^as' }")); - BSONObj expectedOut(fromjson("{match: 'as', idx:0, captures:[]}")); - intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); - ExpressionRegexFind regexF(expCtx); - regexF.addOperand(ExpressionConstant::create(expCtx, input)); - Value output = regexF.evaluate(Document()); - ASSERT_BSONOBJ_EQ(toBson(output.getDocument()), expectedOut); +class ExpressionRegexTest { +public: + template <typename SubClass, int N> + static intrusive_ptr<Expression> generateOptimizedExpression(const BSONObj& input) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto expression = ExpressionFixedArity<SubClass, N>::parse( + expCtx, input.firstElement(), expCtx->variablesParseState); + return expression->optimize(); + } + + static void testAllExpressions(const BSONObj& input, + bool optimized, + const std::vector<Value>& expectedFindAllOutput) { + { + // For $regexFindAll. + auto expression = generateOptimizedExpression<ExpressionRegexFindAll, 1>(input); + auto regexFindAllExpr = dynamic_cast<ExpressionRegexFindAll*>(expression.get()); + ASSERT_EQ(regexFindAllExpr->hasConstantRegex(), optimized); + Value output = regexFindAllExpr->evaluate(Document()); + ASSERT_VALUE_EQ(output, Value(expectedFindAllOutput)); + } + + { + // For $regexFind. + auto expression = generateOptimizedExpression<ExpressionRegexFind, 1>(input); + auto regexFindExpr = dynamic_cast<ExpressionRegexFind*>(expression.get()); + ASSERT_EQ(regexFindExpr->hasConstantRegex(), optimized); + Value output = regexFindExpr->evaluate(Document()); + ASSERT_VALUE_EQ( + output, expectedFindAllOutput.empty() ? Value(BSONNULL) : expectedFindAllOutput[0]); + } + + { + // For $regexMatch. + auto expression = generateOptimizedExpression<ExpressionRegexMatch, 1>(input); + auto regexMatchExpr = dynamic_cast<ExpressionRegexMatch*>(expression.get()); + ASSERT_EQ(regexMatchExpr->hasConstantRegex(), optimized); + Value output = regexMatchExpr->evaluate(Document()); + ASSERT_VALUE_EQ(output, expectedFindAllOutput.empty() ? Value(false) : Value(true)); + } + } +}; + +TEST(ExpressionRegexTest, BasicTest) { + ExpressionRegexTest::testAllExpressions( + fromjson("{$regexFindAll : {input: 'asdf', regex: '^as' }}"), + true, + {Value(fromjson("{match: 'as', idx:0, captures:[]}"))}); } -TEST(ExpressionRegexFindTest, ExtendedRegexOptions) { - Value input(fromjson("{input: 'FirstLine\\nSecondLine', regex: '^second' , options: 'mi'}")); - BSONObj expectedOut(fromjson("{match: 'Second', idx:10, captures:[]}")); - intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); - ExpressionRegexFind regexF(expCtx); - regexF.addOperand(ExpressionConstant::create(expCtx, input)); - Value output = regexF.evaluate(Document()); - ASSERT_BSONOBJ_EQ(toBson(output.getDocument()), expectedOut); +TEST(ExpressionRegexTest, ExtendedRegexOptions) { + ExpressionRegexTest::testAllExpressions( + fromjson("{$regexFindAll : {input: 'FirstLine\\nSecondLine', regex: " + "'^second' , options: 'mi'}}"), + true, + {Value(fromjson("{match: 'Second', idx:10, captures:[]}"))}); +} + +TEST(ExpressionRegexTest, MultipleMatches) { + ExpressionRegexTest::testAllExpressions( + fromjson("{$regexFindAll : {input: 'a1b2c3', regex: '([a-c][1-3])' }}"), + true, + {Value(fromjson("{match: 'a1', idx:0, captures:['a1']}")), + Value(fromjson("{match: 'b2', idx:2, captures:['b2']}")), + Value(fromjson("{match: 'c3', idx:4, captures:['c3']}"))}); +} + +TEST(ExpressionRegexTest, OptimizPatternWhenInputIsVariable) { + ExpressionRegexTest::testAllExpressions( + fromjson("{$regexFindAll : {input: '$input', regex: '([a-c][1-3])' }}"), true, {}); +} + +TEST(ExpressionRegexTest, NoOptimizePatternWhenRegexVariable) { + ExpressionRegexTest::testAllExpressions( + fromjson("{$regexFindAll : {input: 'asdf', regex: '$regex' }}"), false, {}); +} + +TEST(ExpressionRegexTest, NoOptimizePatternWhenOptionsVariable) { + ExpressionRegexTest::testAllExpressions( + fromjson("{$regexFindAll : {input: 'asdf', regex: '(asdf)', options: '$options' }}"), + false, + {Value(fromjson("{match: 'asdf', idx:0, captures:['asdf']}"))}); +} + +TEST(ExpressionRegexTest, NoMatch) { + ExpressionRegexTest::testAllExpressions( + fromjson("{$regexFindAll : {input: 'a1b2c3', regex: 'ab' }}"), true, {}); } TEST(ExpressionRegexFindTest, FailureCase) { @@ -5985,27 +6054,6 @@ TEST(ExpressionRegexFindTest, FailureCase) { ASSERT_THROWS_CODE(regexF.evaluate(Document()), DBException, 51105); } -TEST(ExpressionRegexFindAllTest, MultipleMatches) { - Value input(fromjson("{input: 'a1b2c3', regex: '([a-c][1-3])' }")); - std::vector<Value> expectedOut = {Value(fromjson("{match: 'a1', idx:0, captures:['a1']}")), - Value(fromjson("{match: 'b2', idx:2, captures:['b2']}")), - Value(fromjson("{match: 'c3', idx:4, captures:['c3']}"))}; - intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); - ExpressionRegexFindAll regexF(expCtx); - regexF.addOperand(ExpressionConstant::create(expCtx, input)); - Value output = regexF.evaluate(Document()); - ASSERT_VALUE_EQ(output, Value(expectedOut)); -} - -TEST(ExpressionRegexFindAllTest, NoMatch) { - Value input(fromjson("{input: 'a1b2c3', regex: 'ab' }")); - intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); - ExpressionRegexFindAll regexF(expCtx); - regexF.addOperand(ExpressionConstant::create(expCtx, input)); - Value output = regexF.evaluate(Document()); - ASSERT_VALUE_EQ(output, Value(std::vector<Value>())); -} - TEST(ExpressionRegexFindAllTest, FailureCase) { Value input(fromjson("{input: 'FirstLine\\nSecondLine', regex: '[0-9'}")); intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); @@ -6014,6 +6062,14 @@ TEST(ExpressionRegexFindAllTest, FailureCase) { ASSERT_THROWS_CODE(regexF.evaluate(Document()), DBException, 51111); } +TEST(ExpressionRegexMatchTest, FailureCase) { + Value input(fromjson("{regex: 'valid', input: {invalid : 'input'} , options: 'mi'}")); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + ExpressionRegexMatch regexMatchExpr(expCtx); + regexMatchExpr.addOperand(ExpressionConstant::create(expCtx, input)); + ASSERT_THROWS_CODE(regexMatchExpr.evaluate(Document()), DBException, 51104); +} + TEST(ExpressionRegexFindAllTest, InvalidUTF8InInput) { std::string inputField = "1234 "; // Append an invalid UTF-8 character. @@ -6039,32 +6095,6 @@ TEST(ExpressionRegexFindAllTest, InvalidUTF8InRegex) { ASSERT_THROWS_CODE(regexF.evaluate(Document()), DBException, 51111); } -TEST(ExpressionRegexMatchTest, NoMatch) { - Value input(fromjson("{input: 'asdf', regex: '^sd' }")); - Value expectedOut(false); - intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); - ExpressionRegexMatch regexMatchExpr(expCtx); - regexMatchExpr.addOperand(ExpressionConstant::create(expCtx, input)); - ASSERT_VALUE_EQ(regexMatchExpr.evaluate(Document()), expectedOut); -} - -TEST(ExpressionRegexMatchTest, ExtendedRegexOptions) { - Value input(fromjson("{input: 'FirstLine\\nSecondLine', regex: '^second' , options: 'mi'}")); - Value expectedOut(true); - intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); - ExpressionRegexMatch regexMatchExpr(expCtx); - regexMatchExpr.addOperand(ExpressionConstant::create(expCtx, input)); - ASSERT_VALUE_EQ(regexMatchExpr.evaluate(Document()), expectedOut); -} - -TEST(ExpressionRegexMatchTest, FailureCase) { - Value input(fromjson("{regex: 'valid', input: {invalid : 'input'} , options: 'mi'}")); - intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); - ExpressionRegexMatch regexMatchExpr(expCtx); - regexMatchExpr.addOperand(ExpressionConstant::create(expCtx, input)); - ASSERT_THROWS_CODE(regexMatchExpr.evaluate(Document()), DBException, 51104); -} - } // namespace ExpressionRegexTest class All : public Suite { |