diff options
author | Benjamin Murphy <benjamin_murphy@me.com> | 2016-04-13 16:44:53 -0400 |
---|---|---|
committer | Benjamin Murphy <benjamin_murphy@me.com> | 2016-04-29 17:17:43 -0400 |
commit | 7ae631410d8ffe71c74f96d5ab5dd408764b7858 (patch) | |
tree | 870200b436e3da7e9bc9e1d8c2c347fcc3c36db6 /src/mongo/db/pipeline/expression.cpp | |
parent | d87ad6adb45a98f52cb78fd2460ee888482edf95 (diff) | |
download | mongo-7ae631410d8ffe71c74f96d5ab5dd408764b7858.tar.gz |
SERVER-8951 Aggregation now supports the indexOfArray, indexOfBytes, and indexOfCP expressions.
Diffstat (limited to 'src/mongo/db/pipeline/expression.cpp')
-rw-r--r-- | src/mongo/db/pipeline/expression.cpp | 299 |
1 files changed, 248 insertions, 51 deletions
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index f462052d856..8024ec39cf2 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -359,6 +359,44 @@ intrusive_ptr<Expression> Expression::parseOperand(BSONElement exprElement, } } +namespace { +/** + * UTF-8 multi-byte code points consist of one leading byte of the form 11xxxxxx, and potentially + * many continuation bytes of the form 10xxxxxx. This method checks whether 'charByte' is a + * continuation byte. + */ +bool isContinuationByte(char charByte) { + return (charByte & 0xc0) == 0x80; +} + +/** + * UTF-8 multi-byte code points consist of one leading byte of the form 11xxxxxx, and potentially + * many continuation bytes of the form 10xxxxxx. This method checks whether 'charByte' is a leading + * byte. + */ +bool isLeadingByte(char charByte) { + return (charByte & 0xc0) == 0xc0; +} + +/** + * UTF-8 single-byte code points are of the form 0xxxxxxx. This method checks whether 'charByte' is + * a single-byte code point. + */ +bool isSingleByte(char charByte) { + return (charByte & 0x80) == 0x0; +} + +size_t getCodePointLength(char charByte) { + if (isSingleByte(charByte)) { + return 1; + } + + invariant(isLeadingByte(charByte)); + + // In UTF-8, the number of leading ones is the number of bytes the code point takes up. + return countLeadingZeros64(~(uint64_t(charByte) << (64 - 8))); +} +} // namespace /* ----------------------- ExpressionAbs ---------------------------- */ @@ -2118,6 +2156,216 @@ const char* ExpressionIn::getOpName() const { return "$in"; } +/* ----------------------- ExpressionIndexOfArray ------------------ */ + +namespace { + +void uassertIfNotIntegralAndNonNegative(Value val, + StringData expressionName, + StringData argumentName) { + uassert(40096, + str::stream() << expressionName << "requires an integral " << argumentName + << ", found a value of type: " << typeName(val.getType()) + << ", with value: " << val.toString(), + val.integral()); + uassert(40097, + str::stream() << expressionName << " requires a nonnegative " << argumentName + << ", found: " << val.toString(), + val.coerceToInt() >= 0); +} + +} // namespace + +Value ExpressionIndexOfArray::evaluateInternal(Variables* vars) const { + Value arrayArg = vpOperand[0]->evaluateInternal(vars); + + if (arrayArg.nullish()) { + return Value(BSONNULL); + } + + uassert(40090, + str::stream() << "$indexOfArray requires an array as a first argument, found: " + << typeName(arrayArg.getType()), + arrayArg.isArray()); + + std::vector<Value> array = arrayArg.getArray(); + + Value searchItem = vpOperand[1]->evaluateInternal(vars); + + size_t startIndex = 0; + if (vpOperand.size() > 2) { + Value startIndexArg = vpOperand[2]->evaluateInternal(vars); + uassertIfNotIntegralAndNonNegative(startIndexArg, getOpName(), "starting index"); + startIndex = static_cast<size_t>(startIndexArg.coerceToInt()); + } + + size_t endIndex = array.size(); + if (vpOperand.size() > 3) { + Value endIndexArg = vpOperand[3]->evaluateInternal(vars); + uassertIfNotIntegralAndNonNegative(endIndexArg, getOpName(), "ending index"); + // Don't let 'endIndex' exceed the length of the array. + endIndex = std::min(array.size(), static_cast<size_t>(endIndexArg.coerceToInt())); + } + + for (size_t i = startIndex; i < endIndex; i++) { + if (array[i] == searchItem) { + return Value(static_cast<int>(i)); + } + } + + return Value(-1); +} + +REGISTER_EXPRESSION(indexOfArray, ExpressionIndexOfArray::parse); +const char* ExpressionIndexOfArray::getOpName() const { + return "$indexOfArray"; +} + +/* ----------------------- ExpressionIndexOfBytes ------------------ */ + +namespace { + +bool stringHasTokenAtIndex(size_t index, const std::string& input, const std::string& token) { + if (token.size() + index > input.size()) { + return false; + } + return input.compare(index, token.size(), token) == 0; +} + +} // namespace + +Value ExpressionIndexOfBytes::evaluateInternal(Variables* vars) const { + Value stringArg = vpOperand[0]->evaluateInternal(vars); + + if (stringArg.nullish()) { + return Value(BSONNULL); + } + + uassert(40091, + str::stream() << "$indexOfBytes requires a string as the first argument, found: " + << typeName(stringArg.getType()), + stringArg.getType() == String); + const std::string& input = stringArg.getString(); + + Value tokenArg = vpOperand[1]->evaluateInternal(vars); + uassert(40092, + str::stream() << "$indexOfBytes requires a string as the second argument, found: " + << typeName(tokenArg.getType()), + tokenArg.getType() == String); + const std::string& token = tokenArg.getString(); + + size_t startIndex = 0; + if (vpOperand.size() > 2) { + Value startIndexArg = vpOperand[2]->evaluateInternal(vars); + uassertIfNotIntegralAndNonNegative(startIndexArg, getOpName(), "starting index"); + startIndex = static_cast<size_t>(startIndexArg.coerceToInt()); + } + + size_t endIndex = input.size(); + if (vpOperand.size() > 3) { + Value endIndexArg = vpOperand[3]->evaluateInternal(vars); + uassertIfNotIntegralAndNonNegative(endIndexArg, getOpName(), "ending index"); + // Don't let 'endIndex' exceed the length of the string. + endIndex = std::min(input.size(), static_cast<size_t>(endIndexArg.coerceToInt())); + } + + if (startIndex > input.length() || endIndex < startIndex) { + return Value(-1); + } + + size_t position = input.substr(0, endIndex).find(token, startIndex); + if (position == std::string::npos) { + return Value(-1); + } + + return Value(static_cast<int>(position)); +} + +REGISTER_EXPRESSION(indexOfBytes, ExpressionIndexOfBytes::parse); +const char* ExpressionIndexOfBytes::getOpName() const { + return "$indexOfBytes"; +} + +/* ----------------------- ExpressionIndexOfCP --------------------- */ + +Value ExpressionIndexOfCP::evaluateInternal(Variables* vars) const { + Value stringArg = vpOperand[0]->evaluateInternal(vars); + + if (stringArg.nullish()) { + return Value(BSONNULL); + } + + uassert(40093, + str::stream() << "$indexOfCP requires a string as the first argument, found: " + << typeName(stringArg.getType()), + stringArg.getType() == String); + const std::string& input = stringArg.getString(); + + Value tokenArg = vpOperand[1]->evaluateInternal(vars); + uassert(40094, + str::stream() << "$indexOfCP requires a string as the second argument, found: " + << typeName(tokenArg.getType()), + tokenArg.getType() == String); + const std::string& token = tokenArg.getString(); + + size_t startCodePointIndex = 0; + if (vpOperand.size() > 2) { + Value startIndexArg = vpOperand[2]->evaluateInternal(vars); + uassertIfNotIntegralAndNonNegative(startIndexArg, getOpName(), "starting index"); + startCodePointIndex = static_cast<size_t>(startIndexArg.coerceToInt()); + } + + // Compute the length (in code points) of the input, and convert 'startCodePointIndex' to a byte + // index. + size_t codePointLength = 0; + size_t startByteIndex = 0; + for (size_t byteIx = 0; byteIx < input.size(); ++codePointLength) { + if (codePointLength == startCodePointIndex) { + // We have determined the byte at which our search will start. + startByteIndex = byteIx; + } + + uassert( + 40095, "$indexOfCP found bad UTF-8 in the input", !isContinuationByte(input[byteIx])); + byteIx += getCodePointLength(input[byteIx]); + } + + size_t endCodePointIndex = codePointLength; + if (vpOperand.size() > 3) { + Value endIndexArg = vpOperand[3]->evaluateInternal(vars); + uassertIfNotIntegralAndNonNegative(endIndexArg, getOpName(), "ending index"); + + // Don't let 'endCodePointIndex' exceed the number of code points in the string. + endCodePointIndex = + std::min(codePointLength, static_cast<size_t>(endIndexArg.coerceToInt())); + } + + if (startByteIndex == 0 && input.empty() && token.empty()) { + // If we are finding the index of "" in the string "", the below loop will not loop, so we + // need a special case for this. + return Value(0); + } + + // We must keep track of which byte, and which code point, we are examining, being careful not + // to overflow either the length of the string or the ending code point. + + size_t currentCodePointIndex = startCodePointIndex; + for (size_t byteIx = startByteIndex; currentCodePointIndex < endCodePointIndex; + ++currentCodePointIndex) { + if (stringHasTokenAtIndex(byteIx, input, token)) { + return Value(static_cast<int>(currentCodePointIndex)); + } + byteIx += getCodePointLength(input[byteIx]); + } + + return Value(-1); +} + +REGISTER_EXPRESSION(indexOfCP, ExpressionIndexOfCP::parse); +const char* ExpressionIndexOfCP::getOpName() const { + return "$indexOfCP"; +} + /* ----------------------- ExpressionLn ---------------------------- */ Value ExpressionLn::evaluateNumericArg(const Value& numericArg) const { @@ -3071,18 +3319,6 @@ const char* ExpressionSize::getOpName() const { /* ----------------------- ExpressionSplit --------------------------- */ -namespace { - -bool stringHasTokenAtIndex(size_t index, const std::string& input, const std::string& token) { - if (token.size() + index > input.size()) { - return false; - } - - return input.compare(index, token.size(), token) == 0; -} - -} // namespace - Value ExpressionSplit::evaluateInternal(Variables* vars) const { Value inputArg = vpOperand[0]->evaluateInternal(vars); Value separatorArg = vpOperand[1]->evaluateInternal(vars); @@ -3174,45 +3410,6 @@ const char* ExpressionStrcasecmp::getOpName() const { return "$strcasecmp"; } -namespace { -/** - * UTF-8 multi-byte code points consist of one leading byte of the form 11xxxxxx, and potentially - * many continuation bytes of the form 10xxxxxx. This method checks whether 'charByte' is a - * continuation byte. - */ -bool isContinuationByte(char charByte) { - return (charByte & 0xc0) == 0x80; -} - -/** - * UTF-8 multi-byte code points consist of one leading byte of the form 11xxxxxx, and potentially - * many continuation bytes of the form 10xxxxxx. This method checks whether 'charByte' is a leading - * byte. - */ -bool isLeadingByte(char charByte) { - return (charByte & 0xc0) == 0xc0; -} - -/** - * UTF-8 single-byte code points are of the form 0xxxxxxx. This method checks whether 'charByte' is - * a single-byte code point. - */ -bool isSingleByte(char charByte) { - return (charByte & 0x80) == 0x0; -} - -size_t getCodePointLength(char charByte) { - if (isSingleByte(charByte)) { - return 1; - } - - invariant(isLeadingByte(charByte)); - - // In UTF-8, the number of leading ones is the number of bytes the code point takes up. - return countLeadingZeros64(~(uint64_t(charByte) << (64 - 8))); -} -} // namespace - /* ----------------------- ExpressionSubstrBytes ---------------------------- */ Value ExpressionSubstrBytes::evaluateInternal(Variables* vars) const { |