summaryrefslogtreecommitdiff
path: root/src/mongo/db
diff options
context:
space:
mode:
authorIan Boros <ian.boros@10gen.com>2018-11-09 17:29:54 -0500
committerIan Boros <ian.boros@10gen.com>2018-11-15 17:02:13 -0500
commit1ded7067e2d1a6161b15e5a462f8cba2d755c9a6 (patch)
tree55e481796c91b5019b2ca0a057e6a5f1ba4abc35 /src/mongo/db
parent88fc32078da9444eb173743ee9ad2af46db74cf2 (diff)
downloadmongo-1ded7067e2d1a6161b15e5a462f8cba2d755c9a6.tar.gz
SERVER-38070 Fix infinite loop in agg expression
Diffstat (limited to 'src/mongo/db')
-rw-r--r--src/mongo/db/pipeline/expression.cpp262
-rw-r--r--src/mongo/db/pipeline/expression_test.cpp75
2 files changed, 226 insertions, 111 deletions
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp
index 2d6bd47c282..4646c37363f 100644
--- a/src/mongo/db/pipeline/expression.cpp
+++ b/src/mongo/db/pipeline/expression.cpp
@@ -3247,6 +3247,99 @@ const char* ExpressionOr::getOpName() const {
return "$or";
}
+namespace {
+/**
+ * Helper for ExpressionPow to determine wither base^exp can be represented in a 64 bit int.
+ *
+ *'base' and 'exp' are both integers. Assumes 'exp' is in the range [0, 63].
+ */
+bool representableAsLong(long long base, long long exp) {
+ invariant(exp <= 63);
+ invariant(exp >= 0);
+ struct MinMax {
+ long long min;
+ long long max;
+ };
+
+ // Array indices correspond to exponents 0 through 63. The values in each index are the min
+ // and max bases, respectively, that can be raised to that exponent without overflowing a
+ // 64-bit int. For max bases, this was computed by solving for b in
+ // b = (2^63-1)^(1/exp) for exp = [0, 63] and truncating b. To calculate min bases, for even
+ // exps the equation used was b = (2^63-1)^(1/exp), and for odd exps the equation used was
+ // b = (-2^63)^(1/exp). Since the magnitude of long min is greater than long max, the
+ // magnitude of some of the min bases raised to odd exps is greater than the corresponding
+ // max bases raised to the same exponents.
+
+ static const MinMax kBaseLimits[] = {
+ {std::numeric_limits<long long>::min(), std::numeric_limits<long long>::max()}, // 0
+ {std::numeric_limits<long long>::min(), std::numeric_limits<long long>::max()},
+ {-3037000499LL, 3037000499LL},
+ {-2097152, 2097151},
+ {-55108, 55108},
+ {-6208, 6208},
+ {-1448, 1448},
+ {-512, 511},
+ {-234, 234},
+ {-128, 127},
+ {-78, 78}, // 10
+ {-52, 52},
+ {-38, 38},
+ {-28, 28},
+ {-22, 22},
+ {-18, 18},
+ {-15, 15},
+ {-13, 13},
+ {-11, 11},
+ {-9, 9},
+ {-8, 8}, // 20
+ {-8, 7},
+ {-7, 7},
+ {-6, 6},
+ {-6, 6},
+ {-5, 5},
+ {-5, 5},
+ {-5, 5},
+ {-4, 4},
+ {-4, 4},
+ {-4, 4}, // 30
+ {-4, 4},
+ {-3, 3},
+ {-3, 3},
+ {-3, 3},
+ {-3, 3},
+ {-3, 3},
+ {-3, 3},
+ {-3, 3},
+ {-3, 3},
+ {-2, 2}, // 40
+ {-2, 2},
+ {-2, 2},
+ {-2, 2},
+ {-2, 2},
+ {-2, 2},
+ {-2, 2},
+ {-2, 2},
+ {-2, 2},
+ {-2, 2},
+ {-2, 2}, // 50
+ {-2, 2},
+ {-2, 2},
+ {-2, 2},
+ {-2, 2},
+ {-2, 2},
+ {-2, 2},
+ {-2, 2},
+ {-2, 2},
+ {-2, 2},
+ {-2, 2}, // 60
+ {-2, 2},
+ {-2, 2},
+ {-2, 1}};
+
+ return base >= kBaseLimits[exp].min && base <= kBaseLimits[exp].max;
+};
+}
+
/* ----------------------- ExpressionPow ---------------------------- */
intrusive_ptr<Expression> ExpressionPow::create(
@@ -3297,128 +3390,77 @@ Value ExpressionPow::evaluate(const Document& root) const {
return Value(std::pow(baseDouble, expDouble));
}
- // base and exp are both integers.
-
- auto representableAsLong = [](long long base, long long exp) {
- // If exp is greater than 63 and base is not -1, 0, or 1, the result will overflow.
- // If exp is negative and the base is not -1 or 1, the result will be fractional.
- if (exp < 0 || exp > 63) {
- return std::abs(base) == 1 || base == 0;
+ // If either number is a long, return a long. If both numbers are ints, then return an int if
+ // the result fits or a long if it is too big.
+ const auto formatResult = [baseType, expType](long long res) {
+ if (baseType == NumberLong || expType == NumberLong) {
+ return Value(res);
}
+ return Value::createIntOrLong(res);
+ };
- struct MinMax {
- long long min;
- long long max;
- };
+ const long long baseLong = baseVal.getLong();
+ const long long expLong = expVal.getLong();
- // Array indices correspond to exponents 0 through 63. The values in each index are the min
- // and max bases, respectively, that can be raised to that exponent without overflowing a
- // 64-bit int. For max bases, this was computed by solving for b in
- // b = (2^63-1)^(1/exp) for exp = [0, 63] and truncating b. To calculate min bases, for even
- // exps the equation used was b = (2^63-1)^(1/exp), and for odd exps the equation used was
- // b = (-2^63)^(1/exp). Since the magnitude of long min is greater than long max, the
- // magnitude of some of the min bases raised to odd exps is greater than the corresponding
- // max bases raised to the same exponents.
-
- static const MinMax kBaseLimits[] = {
- {std::numeric_limits<long long>::min(), std::numeric_limits<long long>::max()}, // 0
- {std::numeric_limits<long long>::min(), std::numeric_limits<long long>::max()},
- {-3037000499LL, 3037000499LL},
- {-2097152, 2097151},
- {-55108, 55108},
- {-6208, 6208},
- {-1448, 1448},
- {-512, 511},
- {-234, 234},
- {-128, 127},
- {-78, 78}, // 10
- {-52, 52},
- {-38, 38},
- {-28, 28},
- {-22, 22},
- {-18, 18},
- {-15, 15},
- {-13, 13},
- {-11, 11},
- {-9, 9},
- {-8, 8}, // 20
- {-8, 7},
- {-7, 7},
- {-6, 6},
- {-6, 6},
- {-5, 5},
- {-5, 5},
- {-5, 5},
- {-4, 4},
- {-4, 4},
- {-4, 4}, // 30
- {-4, 4},
- {-3, 3},
- {-3, 3},
- {-3, 3},
- {-3, 3},
- {-3, 3},
- {-3, 3},
- {-3, 3},
- {-3, 3},
- {-2, 2}, // 40
- {-2, 2},
- {-2, 2},
- {-2, 2},
- {-2, 2},
- {-2, 2},
- {-2, 2},
- {-2, 2},
- {-2, 2},
- {-2, 2},
- {-2, 2}, // 50
- {-2, 2},
- {-2, 2},
- {-2, 2},
- {-2, 2},
- {-2, 2},
- {-2, 2},
- {-2, 2},
- {-2, 2},
- {-2, 2},
- {-2, 2}, // 60
- {-2, 2},
- {-2, 2},
- {-2, 1}};
-
- return base >= kBaseLimits[exp].min && base <= kBaseLimits[exp].max;
+ // Use this when the result cannot be represented as a long.
+ const auto computeDoubleResult = [baseLong, expLong]() {
+ return Value(std::pow(baseLong, expLong));
};
- long long baseLong = baseVal.getLong();
- long long expLong = expVal.getLong();
+ // Avoid doing repeated multiplication or using std::pow if the base is -1, 0, or 1.
+ if (baseLong == 0) {
+ if (expLong == 0) {
+ // 0^0 = 1.
+ return formatResult(1);
+ } else if (expLong > 0) {
+ // 0^x where x > 0 is 0.
+ return formatResult(0);
+ }
- // If the result cannot be represented as a long, return a double. Otherwise if either number is
- // a long, return a long. If both numbers are ints, then return an int if the result fits or a
- // long if it is too big.
+ // We should have checked earlier that 0 to a negative power is banned.
+ MONGO_UNREACHABLE;
+ } else if (baseLong == 1) {
+ return formatResult(1);
+ } else if (baseLong == -1) {
+ // -1^0 = -1^2 = -1^4 = -1^6 ... = 1
+ // -1^1 = -1^3 = -1^5 = -1^7 ... = -1
+ return formatResult((expLong % 2 == 0) ? 1 : -1);
+ } else if (expLong > 63 || expLong < 0) {
+ // If the base is not 0, 1, or -1 and the exponent is too large, or negative,
+ // the result cannot be represented as a long.
+ return computeDoubleResult();
+ }
+
+ // It's still possible that the result cannot be represented as a long. If that's the case,
+ // return a double.
if (!representableAsLong(baseLong, expLong)) {
- return Value(std::pow(baseLong, expLong));
+ return computeDoubleResult();
}
- long long result = 1;
+ // Use repeated multiplication, since pow() casts args to doubles which could result in
+ // loss of precision if arguments are very large.
+ const auto computeWithRepeatedMultiplication = [](long long base, long long exp) {
+ long long result = 1;
- // When 'baseLong' == -1 and 'expLong' is < 0 the following for loop will never run because
- // 'expLong' will always be less than 0 so result will always be 1. This is not always correct
- // because the result can potentially be -1. ex: 'baselong' = -1 'expLong' = -5 then result
- // should be -1.
- if (baseLong == -1 && expLong < 0) {
- expLong = expLong % 2 == 0 ? 2 : 1;
- }
+ while (exp > 1) {
+ if (exp % 2 == 1) {
+ result *= base;
+ exp--;
+ }
+ // 'exp' is now guaranteed to be even.
+ base *= base;
+ exp /= 2;
+ }
- // Use repeated multiplication, since pow() casts args to doubles which could result in loss of
- // precision if arguments are very large.
- for (int i = 0; i < expLong; i++) {
- result *= baseLong;
- }
+ if (exp) {
+ invariant(exp == 1);
+ result *= base;
+ }
- if (baseType == NumberLong || expType == NumberLong) {
- return Value(result);
- }
- return Value::createIntOrLong(result);
+ return result;
+ };
+
+ return formatResult(computeWithRepeatedMultiplication(baseLong, expLong));
}
REGISTER_EXPRESSION(pow, ExpressionPow::parse);
diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp
index 94f9de8329f..2286374fa7a 100644
--- a/src/mongo/db/pipeline/expression_test.cpp
+++ b/src/mongo/db/pipeline/expression_test.cpp
@@ -2234,7 +2234,54 @@ TEST(ExpressionFromAccumulators, StdDevSamp) {
{{}, Value(BSONNULL)}});
}
-TEST(ExpressionPowTest, NegativeOneRaisedToNegativeOddExponentShouldOutPutNegativeOne) {
+TEST(ExpressionPowTest, LargeExponentValuesWithBaseOfZero) {
+ assertExpectedResults(
+ "$pow",
+ {
+ {{Value(0), Value(0)}, Value(1)},
+ {{Value(0LL), Value(0LL)}, Value(1LL)},
+
+ {{Value(0), Value(10)}, Value(0)},
+ {{Value(0), Value(10000)}, Value(0)},
+
+ {{Value(0LL), Value(10)}, Value(0LL)},
+
+ // $pow may sometimes use a loop to compute a^b, so it's important to check
+ // that the loop doesn't hang if a large exponent is provided.
+ {{Value(0LL), Value(std::numeric_limits<long long>::max())}, Value(0LL)},
+ });
+}
+
+TEST(ExpressionPowTest, ThrowsWhenBaseZeroAndExpNegative) {
+ intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
+ VariablesParseState vps = expCtx->variablesParseState;
+
+ const auto expr = Expression::parseExpression(expCtx, BSON("$pow" << BSON_ARRAY(0 << -5)), vps);
+ ASSERT_THROWS([&] { expr->evaluate(Document()); }(), AssertionException);
+
+ const auto exprWithLong =
+ Expression::parseExpression(expCtx, BSON("$pow" << BSON_ARRAY(0LL << -5LL)), vps);
+ ASSERT_THROWS([&] { expr->evaluate(Document()); }(), AssertionException);
+}
+
+TEST(ExpressionPowTest, LargeExponentValuesWithBaseOfOne) {
+ assertExpectedResults(
+ "$pow",
+ {
+ {{Value(1), Value(10)}, Value(1)},
+ {{Value(1), Value(10LL)}, Value(1LL)},
+ {{Value(1), Value(10000LL)}, Value(1LL)},
+
+ {{Value(1LL), Value(10LL)}, Value(1LL)},
+
+ // $pow may sometimes use a loop to compute a^b, so it's important to check
+ // that the loop doesn't hang if a large exponent is provided.
+ {{Value(1LL), Value(std::numeric_limits<long long>::max())}, Value(1LL)},
+ {{Value(1LL), Value(std::numeric_limits<long long>::min())}, Value(1LL)},
+ });
+}
+
+TEST(ExpressionPowTest, LargeExponentValuesWithBaseOfNegativeOne) {
assertExpectedResults("$pow",
{
{{Value(-1), Value(-1)}, Value(-1)},
@@ -2247,6 +2294,32 @@ TEST(ExpressionPowTest, NegativeOneRaisedToNegativeOddExponentShouldOutPutNegati
{{Value(-1LL), Value(-3LL)}, Value(-1LL)},
{{Value(-1LL), Value(-4LL)}, Value(1LL)},
{{Value(-1LL), Value(-5LL)}, Value(-1LL)},
+
+ {{Value(-1LL), Value(-61LL)}, Value(-1LL)},
+ {{Value(-1LL), Value(61LL)}, Value(-1LL)},
+
+ {{Value(-1LL), Value(-62LL)}, Value(1LL)},
+ {{Value(-1LL), Value(62LL)}, Value(1LL)},
+
+ {{Value(-1LL), Value(-101LL)}, Value(-1LL)},
+ {{Value(-1LL), Value(-102LL)}, Value(1LL)},
+
+ // Use a value large enough that will make the test hang for a
+ // considerable amount of time if a loop is used to compute the
+ // answer.
+ {{Value(-1LL), Value(63234673905128LL)}, Value(1LL)},
+ {{Value(-1LL), Value(-63234673905128LL)}, Value(1LL)},
+
+ {{Value(-1LL), Value(63234673905127LL)}, Value(-1LL)},
+ {{Value(-1LL), Value(-63234673905127LL)}, Value(-1LL)},
+ });
+}
+
+TEST(ExpressionPowTest, LargeBaseSmallPositiveExponent) {
+ assertExpectedResults("$pow",
+ {
+ {{Value(4294967296LL), Value(1LL)}, Value(4294967296LL)},
+ {{Value(4294967296LL), Value(0)}, Value(1LL)},
});
}