summaryrefslogtreecommitdiff
path: root/src/mongo/db/pipeline
diff options
context:
space:
mode:
authorArun Banala <arun.banala@mongodb.com>2019-03-08 20:26:09 +0000
committerArun Banala <arun.banala@mongodb.com>2019-03-20 15:09:26 +0000
commit12a560bff2911a29103d05071e260060c77263eb (patch)
tree12e3aee5e0bd4ead46f309b12d7eb598f8277df7 /src/mongo/db/pipeline
parent38c94f316b167e4b54b54ba8d12dbec33c7c5165 (diff)
downloadmongo-12a560bff2911a29103d05071e260060c77263eb.tar.gz
SERVER-39696 Implement $regexFindAll
Diffstat (limited to 'src/mongo/db/pipeline')
-rw-r--r--src/mongo/db/pipeline/expression.cpp338
-rw-r--r--src/mongo/db/pipeline/expression.h11
-rw-r--r--src/mongo/db/pipeline/expression_test.cpp59
3 files changed, 290 insertions, 118 deletions
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp
index d0d54cd7e14..f2bd565989e 100644
--- a/src/mongo/db/pipeline/expression.cpp
+++ b/src/mongo/db/pipeline/expression.cpp
@@ -5661,138 +5661,248 @@ Value ExpressionConvert::performConversion(BSONType targetType, Value inputValue
namespace {
-Value generateRegexCapturesAndMatches(StringData pattern,
- const int numCaptures,
- const pcrecpp::RE_Options& options,
- StringData input,
- int startBytePos,
- int startCodePointPos) {
-
- const auto pcreOptions = options.all_options();
- // The first two-thirds of the vector is used to pass back captured substrings' start and limit
- // indexes. The remaining third of the vector is used as workspace by pcre_exec() while matching
- // capturing subpatterns, and is not available for passing back information.
- const size_t sizeOfOVector = (1 + numCaptures) * 3;
- const char* compile_error;
- int eoffset;
-
- // The C++ interface pcreccp.h doesn't have a way to capture the matched string (or the index of
- // the match). So we are using the C interface. First we compile all the regex options to
- // generate pcre object, which will later be used to match against the input string.
- pcre* pcre = pcre_compile(pattern.rawData(), pcreOptions, &compile_error, &eoffset, nullptr);
- if (pcre == nullptr) {
- uasserted(51111, str::stream() << "Invalid Regex: " << compile_error);
- }
-
- // TODO: Evaluate the upper bound for this array and fail the request if numCaptures are higher
- // than the limit (SERVER-37848).
- std::vector<int> outVector(sizeOfOVector);
- const int out = pcre_exec(pcre,
- 0,
- input.rawData(),
- input.size(),
- startBytePos,
- 0, // No need to overwrite the options set during pcre_compile.
- &outVector.front(),
- sizeOfOVector);
- (*pcre_free)(pcre);
- // The 'out' parameter will be zero if outVector's size is not big enough to hold all the
- // captures, which should never be the case.
- invariant(out != 0);
-
- // No match.
- if (out < 0) {
- return Value(BSONNULL);
+class RegexMatchHandler {
+public:
+ RegexMatchHandler(const Value& inputExpr) : _pcre(nullptr), _nullish(false) {
+ _validateInputAndExtractElements(inputExpr);
+ _compile(regex_util::flags2PcreOptions(_options, false).all_options());
}
- // The first and second entires of the outVector have the start and limit indices of the matched
- // string. as byte offsets.
- const int matchStartByteIndex = outVector[0];
- // We iterate through the input string's contents preceding the match index, in order to convert
- // the byte offset to a code point offset.
- for (int byteIx = startBytePos; byteIx < matchStartByteIndex; ++startCodePointPos) {
- byteIx += getCodePointLength(input[byteIx]);
+ ~RegexMatchHandler() {
+ if (_pcre != nullptr) {
+ pcre_free(_pcre);
+ }
}
- StringData matchedStr = input.substr(outVector[0], outVector[1] - outVector[0]);
- std::vector<Value> captures;
- // The next 2 * numCaptures entries hold the start index and limit pairs, for each of the
- // capture groups. We skip the first two elements and start iteration from 3rd element so that
- // we only construct the strings for capture groups.
- for (int i = 0; i < numCaptures; i++) {
- const int start = outVector[2 * (i + 1)];
- const int limit = outVector[2 * (i + 1) + 1];
- captures.push_back(Value(input.substr(start, limit - start)));
- }
+ /**
+ * The function will match '_input' string based on the regex pattern present in '_pcre'. If
+ * there is a match, the function will return a 'Value' object encapsulating the matched string,
+ * the code point index of the matched string and a vector representing all the captured
+ * substrings. The function will also update the parameters 'startBytePos' and
+ * 'startCodePointPos' to the corresponding new indices. If there is no match, the function will
+ * return null 'Value' object.
+ */
+ Value nextMatch(int* startBytePos, int* startCodePointPos) {
+ invariant(startBytePos != nullptr && startCodePointPos != nullptr);
+
+ // Use input as StringData throughout the function to avoid copying the string on 'substr'
+ // calls.
+ StringData input = _input;
+ int execResult = pcre_exec(_pcre,
+ 0,
+ input.rawData(),
+ input.size(),
+ *startBytePos,
+ 0, // No need to overwrite the options set during pcre_compile.
+ &_capturesBuffer.front(),
+ _capturesBuffer.size());
+ // No match.
+ if (execResult < 0) {
+ return Value(BSONNULL);
+ }
+ // The 'execResult' will be zero if _capturesBuffer's size is not big enough to hold all
+ // the captures, which should never be the case.
+ invariant(execResult == _numCaptures + 1);
+
+ // The first and second entries of the '_capturesBuffer' will have the start and limit
+ // indices of the matched string, as byte offsets. '(limit - startIndex)' would be the
+ // length of the captured string.
+ const int matchStartByteIndex = _capturesBuffer[0];
+ StringData matchedStr =
+ input.substr(matchStartByteIndex, _capturesBuffer[1] - matchStartByteIndex);
+ // We iterate through the input string's contents preceding the match index, in order to
+ // convert the byte offset to a code point offset.
+ for (int byteIx = *startBytePos; byteIx < matchStartByteIndex; ++(*startCodePointPos)) {
+ byteIx += getCodePointLength(input[byteIx]);
+ }
+ // Set the start index for match to the new one.
+ *startBytePos = matchStartByteIndex;
+
+ std::vector<Value> captures;
+ captures.reserve(_numCaptures);
+ // The next '2 * numCaptures' entries (after the first two entries) of '_capturesBuffer'
+ // will hold the start index and limit pairs, for each of the capture groups. We skip the
+ // first two elements and start iteration from 3rd element so that we only construct the
+ // strings for capture groups.
+ for (int i = 0; i < _numCaptures; ++i) {
+ const int start = _capturesBuffer[2 * (i + 1)];
+ const int limit = _capturesBuffer[2 * (i + 1) + 1];
+ captures.push_back(Value(input.substr(start, limit - start)));
+ }
- MutableDocument match;
- match.addField("match", Value(matchedStr));
- match.addField("idx", Value(startCodePointPos));
- match.addField("captures", Value(captures));
- return match.freezeToValue();
-}
+ MutableDocument match;
+ match.addField("match", Value(matchedStr));
+ match.addField("idx", Value(*startCodePointPos));
+ match.addField("captures", Value(captures));
+ return match.freezeToValue();
+ }
-} // namespace
+ int numCaptures() {
+ return _numCaptures;
+ }
-Value ExpressionRegexFind::evaluate(const Document& root) const {
+ bool nullish() {
+ return _nullish;
+ }
- const Value expr = vpOperand[0]->evaluate(root);
- uassert(51103,
- str::stream() << "$regexFind expects an object of named arguments, but found type "
- << expr.getType(),
- !expr.nullish() && expr.getType() == BSONType::Object);
- Value textInput = expr.getDocument().getField("input");
- Value regexPattern = expr.getDocument().getField("regex");
- Value regexOptions = expr.getDocument().getField("options");
-
- uassert(51104,
- "input field should be of type string",
- textInput.nullish() || textInput.getType() == BSONType::String);
- uassert(51105,
- "regex field should be of type string or regex",
- regexPattern.nullish() || regexPattern.getType() == BSONType::String ||
- regexPattern.getType() == BSONType::RegEx);
- uassert(51106,
- "options should be of type string",
- regexOptions.nullish() || regexOptions.getType() == BSONType::String);
- if (textInput.nullish() || regexPattern.nullish()) {
- return Value(BSONNULL);
+ StringData getInput() {
+ return _input;
}
- StringData pattern, optionFlags;
- // The 'regex' field can be a RegEx object with its own options/options specified separately...
- if (regexPattern.getType() == BSONType::RegEx) {
- StringData regexFlags = regexPattern.getRegexFlags();
- pattern = regexPattern.getRegex();
- uassert(
- 51107,
- str::stream() << "Found regex option(s) specified in both 'regex' and 'option' fields",
- regexOptions.nullish() || regexFlags.empty());
- optionFlags = regexOptions.nullish() ? regexFlags : regexOptions.getStringData();
- } else {
- // ... or it can be a string field with options specified separately.
- pattern = regexPattern.getStringData();
+private:
+ RegexMatchHandler(const RegexMatchHandler&) = delete;
+
+ void _compile(const int pcreOptions) {
+ const char* compile_error;
+ int eoffset;
+ // The C++ interface pcreccp.h doesn't have a way to capture the matched string (or the
+ // index of the match). So we are using the C interface. First we compile all the regex
+ // options to generate pcre object, which will later be used to match against the input
+ // string.
+ _pcre = pcre_compile(_pattern.c_str(), pcreOptions, &compile_error, &eoffset, nullptr);
+ if (_pcre == nullptr) {
+ uasserted(51111, str::stream() << "Invalid Regex: " << compile_error);
+ }
+
+ // Calculate the number of capture groups present in '_pattern' and store in '_numCaptures'.
+ int pcre_retval = pcre_fullinfo(_pcre, NULL, PCRE_INFO_CAPTURECOUNT, &_numCaptures);
+ invariant(pcre_retval == 0);
+
+ // The first two-thirds of the vector is used to pass back captured substrings' start and
+ // limit indexes. The remaining third of the vector is used as workspace by pcre_exec()
+ // while matching capturing subpatterns, and is not available for passing back information.
+ // TODO: Evaluate the upper bound for this array and fail the request if numCaptures are
+ // higher than the limit (SERVER-37848).
+ _capturesBuffer = std::vector<int>((1 + _numCaptures) * 3);
+ }
+
+ void _validateInputAndExtractElements(const Value& inputExpr) {
+ uassert(51103,
+ str::stream() << "$regexFind expects an object of named arguments, but found type "
+ << inputExpr.getType(),
+ inputExpr.getType() == BSONType::Object);
+ Value textInput = inputExpr.getDocument().getField("input");
+ Value regexPattern = inputExpr.getDocument().getField("regex");
+ Value regexOptions = inputExpr.getDocument().getField("options");
+
+ uassert(51104,
+ "'input' field should be of type string",
+ textInput.nullish() || textInput.getType() == BSONType::String);
+ uassert(51105,
+ "'regex' field should be of type string or regex",
+ regexPattern.nullish() || regexPattern.getType() == BSONType::String ||
+ regexPattern.getType() == BSONType::RegEx);
+ uassert(51106,
+ "'options' should be of type string",
+ regexOptions.nullish() || regexOptions.getType() == BSONType::String);
+
+ // If either the text input or regex pattern is nullish, then we consider the operation as a
+ // whole nullish.
+ _nullish = textInput.nullish() || regexPattern.nullish();
+
+ if (textInput.getType() == BSONType::String) {
+ _input = textInput.getString();
+ }
+
+ // The 'regex' field can be a RegEx object and may have its own options...
+ if (regexPattern.getType() == BSONType::RegEx) {
+ StringData regexFlags = regexPattern.getRegexFlags();
+ _pattern = regexPattern.getRegex();
+ uassert(51107,
+ str::stream()
+ << "Found regex option(s) specified in both 'regex' and 'option' fields",
+ regexOptions.nullish() || regexFlags.empty());
+ if (!regexFlags.empty()) {
+ _options = regexFlags.toString();
+ }
+ } else if (regexPattern.getType() == BSONType::String) {
+ // ...or it can be a string field with options specified separately.
+ _pattern = regexPattern.getString();
+ }
+ // If 'options' is non-null, we must extract and validate its contents even if
+ // 'regexPattern' is nullish.
if (!regexOptions.nullish()) {
- optionFlags = regexOptions.getStringData();
+ _options = regexOptions.getString();
}
- }
- uassert(51109,
- "Regular expression cannot contain an embedded null byte",
- pattern.find('\0', 0) == string::npos);
- uassert(51110,
- "Regular expression options string cannot contain an embedded null byte",
- optionFlags.find('\0', 0) == string::npos);
+ uassert(51109,
+ "Regular expression cannot contain an embedded null byte",
+ _pattern.find('\0', 0) == string::npos);
+ uassert(51110,
+ "Regular expression options string cannot contain an embedded null byte",
+ _options.find('\0', 0) == string::npos);
+ }
+
+ pcre* _pcre;
+ // Number of capture groups present in '_pattern'.
+ int _numCaptures;
+ // Holds the start and limit indices of match and captures for the current match.
+ std::vector<int> _capturesBuffer;
+ std::string _input;
+ std::string _pattern;
+ std::string _options;
+ bool _nullish;
+};
- pcrecpp::RE_Options opt = regex_util::flags2PcreOptions(optionFlags, false);
- pcrecpp::RE regex(pattern.rawData(), opt);
- return generateRegexCapturesAndMatches(
- pattern, regex.NumberOfCapturingGroups(), opt, textInput.getStringData(), 0, 0);
-}
+} // namespace
+Value ExpressionRegexFind::evaluate(const Document& root) const {
+
+ RegexMatchHandler regex(vpOperand[0]->evaluate(root));
+ if (regex.nullish()) {
+ return Value(BSONNULL);
+ }
+ int startByteIndex = 0, startCodePointIndex = 0;
+ return regex.nextMatch(&startByteIndex, &startCodePointIndex);
+}
REGISTER_EXPRESSION(regexFind, ExpressionRegexFind::parse);
const char* ExpressionRegexFind::getOpName() const {
return "$regexFind";
}
+Value ExpressionRegexFindAll::evaluate(const Document& root) const {
+
+ std::vector<Value> output;
+ RegexMatchHandler regex(vpOperand[0]->evaluate(root));
+ if (regex.nullish()) {
+ return Value(output);
+ }
+ int startByteIndex = 0, startCodePointIndex = 0;
+ StringData input = regex.getInput();
+
+ // Using do...while loop because, when input is an empty string, we still want to see if there
+ // is a match.
+ do {
+ auto matchObj = regex.nextMatch(&startByteIndex, &startCodePointIndex);
+ if (matchObj.getType() == BSONType::jstNULL) {
+ break;
+ }
+ output.push_back(matchObj);
+ std::string matchStr = matchObj.getDocument().getField("match").getString();
+ if (matchStr.empty()) {
+ // This would only happen if the regex matched an empty string. In this case, even if
+ // the character at startByteIndex matches the regex, we cannot return it since we are
+ // already returing an empty string starting at this index. So we move on to the next
+ // byte index.
+ startByteIndex += getCodePointLength(input[startByteIndex]);
+ ++startCodePointIndex;
+ continue;
+ }
+ // We don't want any overlapping sub-strings. So we move 'startByteIndex' to point to the
+ // byte after 'matchStr'. We move the code point index also correspondingly.
+ startByteIndex += matchStr.size();
+ for (size_t byteIx = 0; byteIx < matchStr.size(); ++startCodePointIndex) {
+ byteIx += getCodePointLength(matchStr[byteIx]);
+ }
+ invariant(startByteIndex > 0 && startCodePointIndex > 0 &&
+ startCodePointIndex <= startByteIndex);
+ } while (static_cast<size_t>(startByteIndex) < input.size());
+ return Value(output);
+}
+
+REGISTER_EXPRESSION(regexFindAll, ExpressionRegexFindAll::parse);
+const char* ExpressionRegexFindAll::getOpName() const {
+ return "$regexFindAll";
+}
+
} // namespace mongo
diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h
index 4aa91a67086..b0949cca3fc 100644
--- a/src/mongo/db/pipeline/expression.h
+++ b/src/mongo/db/pipeline/expression.h
@@ -2093,7 +2093,7 @@ private:
boost::intrusive_ptr<Expression> _onNull;
};
-class ExpressionRegexFind final : public ExpressionFixedArity<ExpressionRegexFind, 1> {
+class ExpressionRegexFind : public ExpressionFixedArity<ExpressionRegexFind, 1> {
public:
explicit ExpressionRegexFind(const boost::intrusive_ptr<ExpressionContext>& expCtx)
: ExpressionFixedArity<ExpressionRegexFind, 1>(expCtx) {}
@@ -2101,4 +2101,13 @@ public:
Value evaluate(const Document& root) const final;
const char* getOpName() const final;
};
+
+class ExpressionRegexFindAll final : public ExpressionFixedArity<ExpressionRegexFindAll, 1> {
+public:
+ explicit ExpressionRegexFindAll(const boost::intrusive_ptr<ExpressionContext>& expCtx)
+ : ExpressionFixedArity<ExpressionRegexFindAll, 1>(expCtx) {}
+
+ Value evaluate(const Document& root) const final;
+ const char* getOpName() const final;
+};
}
diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp
index ca254af56c7..c9bc46a2c8b 100644
--- a/src/mongo/db/pipeline/expression_test.cpp
+++ b/src/mongo/db/pipeline/expression_test.cpp
@@ -5950,7 +5950,7 @@ TEST(GetComputedPathsTest, ExpressionMapNotConsideredRenameWithDottedInputPath)
} // namespace GetComputedPathsTest
-namespace ExpressionRegexFindTest {
+namespace ExpressionRegexTest {
TEST(ExpressionRegexFindTest, BasicTest) {
Value input(fromjson("{input: 'asdf', regex: '^as' }"));
@@ -5979,11 +5979,64 @@ TEST(ExpressionRegexFindTest, FailureCase) {
intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
ExpressionRegexFind regexF(expCtx);
regexF.addOperand(ExpressionConstant::create(expCtx, input));
- ASSERT_THROWS(regexF.evaluate(Document()), DBException);
+ ASSERT_THROWS_CODE(regexF.evaluate(Document()), DBException, 51105);
}
+TEST(ExpressionRegexFindAllTest, MultipleMatches) {
+ Value input(fromjson("{input: 'a1b2c3', regex: '([a-c][1-3])' }"));
+ std::vector<Value> expectedOut = {Value(fromjson("{match: 'a1', idx:0, captures:['a1']}")),
+ Value(fromjson("{match: 'b2', idx:2, captures:['b2']}")),
+ Value(fromjson("{match: 'c3', idx:4, captures:['c3']}"))};
+ intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
+ ExpressionRegexFindAll regexF(expCtx);
+ regexF.addOperand(ExpressionConstant::create(expCtx, input));
+ Value output = regexF.evaluate(Document());
+ ASSERT_VALUE_EQ(output, Value(expectedOut));
+}
+
+TEST(ExpressionRegexFindAllTest, NoMatch) {
+ Value input(fromjson("{input: 'a1b2c3', regex: 'ab' }"));
+ intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
+ ExpressionRegexFindAll regexF(expCtx);
+ regexF.addOperand(ExpressionConstant::create(expCtx, input));
+ Value output = regexF.evaluate(Document());
+ ASSERT_VALUE_EQ(output, Value(std::vector<Value>()));
+}
+
+TEST(ExpressionRegexFindAllTest, FailureCase) {
+ Value input(fromjson("{input: 'FirstLine\\nSecondLine', regex: '[0-9'}"));
+ intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
+ ExpressionRegexFindAll regexF(expCtx);
+ regexF.addOperand(ExpressionConstant::create(expCtx, input));
+ ASSERT_THROWS_CODE(regexF.evaluate(Document()), DBException, 51111);
+}
+
+TEST(ExpressionRegexFindAllTest, InvalidUTF8InInput) {
+ std::string inputField = "1234 ";
+ // Append an invalid UTF-8 character.
+ inputField += static_cast<char>(0xE5);
+ inputField += " 1234";
+ Value input(fromjson("{input: '" + inputField + "', regex: '[0-9]'}"));
+ intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
+ ExpressionRegexFindAll regexF(expCtx);
+ regexF.addOperand(ExpressionConstant::create(expCtx, input));
+ // Verify no match if there is an invalid UTF-8 character in input.
+ ASSERT_VALUE_EQ(regexF.evaluate(Document()), Value(std::vector<Value>()));
+}
+
+TEST(ExpressionRegexFindAllTest, InvalidUTF8InRegex) {
+ std::string regexField = "1234 ";
+ // Append an invalid UTF-8 character.
+ regexField += static_cast<char>(0xE5);
+ Value input(fromjson("{input: '123456', regex: '" + regexField + "'}"));
+ intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
+ ExpressionRegexFindAll regexF(expCtx);
+ regexF.addOperand(ExpressionConstant::create(expCtx, input));
+ // Verify that PCRE will error if REGEX is not a valid UTF-8.
+ ASSERT_THROWS_CODE(regexF.evaluate(Document()), DBException, 51111);
+}
-} // namespace ExpressionRegexFindTest
+} // namespace ExpressionRegexTest
class All : public Suite {
public: