summaryrefslogtreecommitdiff
path: root/src/mongo/db/pipeline/expression.cpp
diff options
context:
space:
mode:
authorBenjamin Murphy <benjamin_murphy@me.com>2016-04-13 16:44:53 -0400
committerBenjamin Murphy <benjamin_murphy@me.com>2016-04-29 17:17:43 -0400
commit7ae631410d8ffe71c74f96d5ab5dd408764b7858 (patch)
tree870200b436e3da7e9bc9e1d8c2c347fcc3c36db6 /src/mongo/db/pipeline/expression.cpp
parentd87ad6adb45a98f52cb78fd2460ee888482edf95 (diff)
downloadmongo-7ae631410d8ffe71c74f96d5ab5dd408764b7858.tar.gz
SERVER-8951 Aggregation now supports the indexOfArray, indexOfBytes, and indexOfCP expressions.
Diffstat (limited to 'src/mongo/db/pipeline/expression.cpp')
-rw-r--r--src/mongo/db/pipeline/expression.cpp299
1 files changed, 248 insertions, 51 deletions
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp
index f462052d856..8024ec39cf2 100644
--- a/src/mongo/db/pipeline/expression.cpp
+++ b/src/mongo/db/pipeline/expression.cpp
@@ -359,6 +359,44 @@ intrusive_ptr<Expression> Expression::parseOperand(BSONElement exprElement,
}
}
+namespace {
+/**
+ * UTF-8 multi-byte code points consist of one leading byte of the form 11xxxxxx, and potentially
+ * many continuation bytes of the form 10xxxxxx. This method checks whether 'charByte' is a
+ * continuation byte.
+ */
+bool isContinuationByte(char charByte) {
+ return (charByte & 0xc0) == 0x80;
+}
+
+/**
+ * UTF-8 multi-byte code points consist of one leading byte of the form 11xxxxxx, and potentially
+ * many continuation bytes of the form 10xxxxxx. This method checks whether 'charByte' is a leading
+ * byte.
+ */
+bool isLeadingByte(char charByte) {
+ return (charByte & 0xc0) == 0xc0;
+}
+
+/**
+ * UTF-8 single-byte code points are of the form 0xxxxxxx. This method checks whether 'charByte' is
+ * a single-byte code point.
+ */
+bool isSingleByte(char charByte) {
+ return (charByte & 0x80) == 0x0;
+}
+
+size_t getCodePointLength(char charByte) {
+ if (isSingleByte(charByte)) {
+ return 1;
+ }
+
+ invariant(isLeadingByte(charByte));
+
+ // In UTF-8, the number of leading ones is the number of bytes the code point takes up.
+ return countLeadingZeros64(~(uint64_t(charByte) << (64 - 8)));
+}
+} // namespace
/* ----------------------- ExpressionAbs ---------------------------- */
@@ -2118,6 +2156,216 @@ const char* ExpressionIn::getOpName() const {
return "$in";
}
+/* ----------------------- ExpressionIndexOfArray ------------------ */
+
+namespace {
+
+void uassertIfNotIntegralAndNonNegative(Value val,
+ StringData expressionName,
+ StringData argumentName) {
+ uassert(40096,
+ str::stream() << expressionName << "requires an integral " << argumentName
+ << ", found a value of type: " << typeName(val.getType())
+ << ", with value: " << val.toString(),
+ val.integral());
+ uassert(40097,
+ str::stream() << expressionName << " requires a nonnegative " << argumentName
+ << ", found: " << val.toString(),
+ val.coerceToInt() >= 0);
+}
+
+} // namespace
+
+Value ExpressionIndexOfArray::evaluateInternal(Variables* vars) const {
+ Value arrayArg = vpOperand[0]->evaluateInternal(vars);
+
+ if (arrayArg.nullish()) {
+ return Value(BSONNULL);
+ }
+
+ uassert(40090,
+ str::stream() << "$indexOfArray requires an array as a first argument, found: "
+ << typeName(arrayArg.getType()),
+ arrayArg.isArray());
+
+ std::vector<Value> array = arrayArg.getArray();
+
+ Value searchItem = vpOperand[1]->evaluateInternal(vars);
+
+ size_t startIndex = 0;
+ if (vpOperand.size() > 2) {
+ Value startIndexArg = vpOperand[2]->evaluateInternal(vars);
+ uassertIfNotIntegralAndNonNegative(startIndexArg, getOpName(), "starting index");
+ startIndex = static_cast<size_t>(startIndexArg.coerceToInt());
+ }
+
+ size_t endIndex = array.size();
+ if (vpOperand.size() > 3) {
+ Value endIndexArg = vpOperand[3]->evaluateInternal(vars);
+ uassertIfNotIntegralAndNonNegative(endIndexArg, getOpName(), "ending index");
+ // Don't let 'endIndex' exceed the length of the array.
+ endIndex = std::min(array.size(), static_cast<size_t>(endIndexArg.coerceToInt()));
+ }
+
+ for (size_t i = startIndex; i < endIndex; i++) {
+ if (array[i] == searchItem) {
+ return Value(static_cast<int>(i));
+ }
+ }
+
+ return Value(-1);
+}
+
+REGISTER_EXPRESSION(indexOfArray, ExpressionIndexOfArray::parse);
+const char* ExpressionIndexOfArray::getOpName() const {
+ return "$indexOfArray";
+}
+
+/* ----------------------- ExpressionIndexOfBytes ------------------ */
+
+namespace {
+
+bool stringHasTokenAtIndex(size_t index, const std::string& input, const std::string& token) {
+ if (token.size() + index > input.size()) {
+ return false;
+ }
+ return input.compare(index, token.size(), token) == 0;
+}
+
+} // namespace
+
+Value ExpressionIndexOfBytes::evaluateInternal(Variables* vars) const {
+ Value stringArg = vpOperand[0]->evaluateInternal(vars);
+
+ if (stringArg.nullish()) {
+ return Value(BSONNULL);
+ }
+
+ uassert(40091,
+ str::stream() << "$indexOfBytes requires a string as the first argument, found: "
+ << typeName(stringArg.getType()),
+ stringArg.getType() == String);
+ const std::string& input = stringArg.getString();
+
+ Value tokenArg = vpOperand[1]->evaluateInternal(vars);
+ uassert(40092,
+ str::stream() << "$indexOfBytes requires a string as the second argument, found: "
+ << typeName(tokenArg.getType()),
+ tokenArg.getType() == String);
+ const std::string& token = tokenArg.getString();
+
+ size_t startIndex = 0;
+ if (vpOperand.size() > 2) {
+ Value startIndexArg = vpOperand[2]->evaluateInternal(vars);
+ uassertIfNotIntegralAndNonNegative(startIndexArg, getOpName(), "starting index");
+ startIndex = static_cast<size_t>(startIndexArg.coerceToInt());
+ }
+
+ size_t endIndex = input.size();
+ if (vpOperand.size() > 3) {
+ Value endIndexArg = vpOperand[3]->evaluateInternal(vars);
+ uassertIfNotIntegralAndNonNegative(endIndexArg, getOpName(), "ending index");
+ // Don't let 'endIndex' exceed the length of the string.
+ endIndex = std::min(input.size(), static_cast<size_t>(endIndexArg.coerceToInt()));
+ }
+
+ if (startIndex > input.length() || endIndex < startIndex) {
+ return Value(-1);
+ }
+
+ size_t position = input.substr(0, endIndex).find(token, startIndex);
+ if (position == std::string::npos) {
+ return Value(-1);
+ }
+
+ return Value(static_cast<int>(position));
+}
+
+REGISTER_EXPRESSION(indexOfBytes, ExpressionIndexOfBytes::parse);
+const char* ExpressionIndexOfBytes::getOpName() const {
+ return "$indexOfBytes";
+}
+
+/* ----------------------- ExpressionIndexOfCP --------------------- */
+
+Value ExpressionIndexOfCP::evaluateInternal(Variables* vars) const {
+ Value stringArg = vpOperand[0]->evaluateInternal(vars);
+
+ if (stringArg.nullish()) {
+ return Value(BSONNULL);
+ }
+
+ uassert(40093,
+ str::stream() << "$indexOfCP requires a string as the first argument, found: "
+ << typeName(stringArg.getType()),
+ stringArg.getType() == String);
+ const std::string& input = stringArg.getString();
+
+ Value tokenArg = vpOperand[1]->evaluateInternal(vars);
+ uassert(40094,
+ str::stream() << "$indexOfCP requires a string as the second argument, found: "
+ << typeName(tokenArg.getType()),
+ tokenArg.getType() == String);
+ const std::string& token = tokenArg.getString();
+
+ size_t startCodePointIndex = 0;
+ if (vpOperand.size() > 2) {
+ Value startIndexArg = vpOperand[2]->evaluateInternal(vars);
+ uassertIfNotIntegralAndNonNegative(startIndexArg, getOpName(), "starting index");
+ startCodePointIndex = static_cast<size_t>(startIndexArg.coerceToInt());
+ }
+
+ // Compute the length (in code points) of the input, and convert 'startCodePointIndex' to a byte
+ // index.
+ size_t codePointLength = 0;
+ size_t startByteIndex = 0;
+ for (size_t byteIx = 0; byteIx < input.size(); ++codePointLength) {
+ if (codePointLength == startCodePointIndex) {
+ // We have determined the byte at which our search will start.
+ startByteIndex = byteIx;
+ }
+
+ uassert(
+ 40095, "$indexOfCP found bad UTF-8 in the input", !isContinuationByte(input[byteIx]));
+ byteIx += getCodePointLength(input[byteIx]);
+ }
+
+ size_t endCodePointIndex = codePointLength;
+ if (vpOperand.size() > 3) {
+ Value endIndexArg = vpOperand[3]->evaluateInternal(vars);
+ uassertIfNotIntegralAndNonNegative(endIndexArg, getOpName(), "ending index");
+
+ // Don't let 'endCodePointIndex' exceed the number of code points in the string.
+ endCodePointIndex =
+ std::min(codePointLength, static_cast<size_t>(endIndexArg.coerceToInt()));
+ }
+
+ if (startByteIndex == 0 && input.empty() && token.empty()) {
+ // If we are finding the index of "" in the string "", the below loop will not loop, so we
+ // need a special case for this.
+ return Value(0);
+ }
+
+ // We must keep track of which byte, and which code point, we are examining, being careful not
+ // to overflow either the length of the string or the ending code point.
+
+ size_t currentCodePointIndex = startCodePointIndex;
+ for (size_t byteIx = startByteIndex; currentCodePointIndex < endCodePointIndex;
+ ++currentCodePointIndex) {
+ if (stringHasTokenAtIndex(byteIx, input, token)) {
+ return Value(static_cast<int>(currentCodePointIndex));
+ }
+ byteIx += getCodePointLength(input[byteIx]);
+ }
+
+ return Value(-1);
+}
+
+REGISTER_EXPRESSION(indexOfCP, ExpressionIndexOfCP::parse);
+const char* ExpressionIndexOfCP::getOpName() const {
+ return "$indexOfCP";
+}
+
/* ----------------------- ExpressionLn ---------------------------- */
Value ExpressionLn::evaluateNumericArg(const Value& numericArg) const {
@@ -3071,18 +3319,6 @@ const char* ExpressionSize::getOpName() const {
/* ----------------------- ExpressionSplit --------------------------- */
-namespace {
-
-bool stringHasTokenAtIndex(size_t index, const std::string& input, const std::string& token) {
- if (token.size() + index > input.size()) {
- return false;
- }
-
- return input.compare(index, token.size(), token) == 0;
-}
-
-} // namespace
-
Value ExpressionSplit::evaluateInternal(Variables* vars) const {
Value inputArg = vpOperand[0]->evaluateInternal(vars);
Value separatorArg = vpOperand[1]->evaluateInternal(vars);
@@ -3174,45 +3410,6 @@ const char* ExpressionStrcasecmp::getOpName() const {
return "$strcasecmp";
}
-namespace {
-/**
- * UTF-8 multi-byte code points consist of one leading byte of the form 11xxxxxx, and potentially
- * many continuation bytes of the form 10xxxxxx. This method checks whether 'charByte' is a
- * continuation byte.
- */
-bool isContinuationByte(char charByte) {
- return (charByte & 0xc0) == 0x80;
-}
-
-/**
- * UTF-8 multi-byte code points consist of one leading byte of the form 11xxxxxx, and potentially
- * many continuation bytes of the form 10xxxxxx. This method checks whether 'charByte' is a leading
- * byte.
- */
-bool isLeadingByte(char charByte) {
- return (charByte & 0xc0) == 0xc0;
-}
-
-/**
- * UTF-8 single-byte code points are of the form 0xxxxxxx. This method checks whether 'charByte' is
- * a single-byte code point.
- */
-bool isSingleByte(char charByte) {
- return (charByte & 0x80) == 0x0;
-}
-
-size_t getCodePointLength(char charByte) {
- if (isSingleByte(charByte)) {
- return 1;
- }
-
- invariant(isLeadingByte(charByte));
-
- // In UTF-8, the number of leading ones is the number of bytes the code point takes up.
- return countLeadingZeros64(~(uint64_t(charByte) << (64 - 8)));
-}
-} // namespace
-
/* ----------------------- ExpressionSubstrBytes ---------------------------- */
Value ExpressionSubstrBytes::evaluateInternal(Variables* vars) const {