diff options
author | Ian Boros <ian.boros@10gen.com> | 2018-11-09 17:29:54 -0500 |
---|---|---|
committer | Ian Boros <ian.boros@10gen.com> | 2018-11-15 17:02:13 -0500 |
commit | 1ded7067e2d1a6161b15e5a462f8cba2d755c9a6 (patch) | |
tree | 55e481796c91b5019b2ca0a057e6a5f1ba4abc35 /src/mongo/db | |
parent | 88fc32078da9444eb173743ee9ad2af46db74cf2 (diff) | |
download | mongo-1ded7067e2d1a6161b15e5a462f8cba2d755c9a6.tar.gz |
SERVER-38070 Fix infinite loop in agg expression
Diffstat (limited to 'src/mongo/db')
-rw-r--r-- | src/mongo/db/pipeline/expression.cpp | 262 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression_test.cpp | 75 |
2 files changed, 226 insertions, 111 deletions
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index 2d6bd47c282..4646c37363f 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -3247,6 +3247,99 @@ const char* ExpressionOr::getOpName() const { return "$or"; } +namespace { +/** + * Helper for ExpressionPow to determine wither base^exp can be represented in a 64 bit int. + * + *'base' and 'exp' are both integers. Assumes 'exp' is in the range [0, 63]. + */ +bool representableAsLong(long long base, long long exp) { + invariant(exp <= 63); + invariant(exp >= 0); + struct MinMax { + long long min; + long long max; + }; + + // Array indices correspond to exponents 0 through 63. The values in each index are the min + // and max bases, respectively, that can be raised to that exponent without overflowing a + // 64-bit int. For max bases, this was computed by solving for b in + // b = (2^63-1)^(1/exp) for exp = [0, 63] and truncating b. To calculate min bases, for even + // exps the equation used was b = (2^63-1)^(1/exp), and for odd exps the equation used was + // b = (-2^63)^(1/exp). Since the magnitude of long min is greater than long max, the + // magnitude of some of the min bases raised to odd exps is greater than the corresponding + // max bases raised to the same exponents. + + static const MinMax kBaseLimits[] = { + {std::numeric_limits<long long>::min(), std::numeric_limits<long long>::max()}, // 0 + {std::numeric_limits<long long>::min(), std::numeric_limits<long long>::max()}, + {-3037000499LL, 3037000499LL}, + {-2097152, 2097151}, + {-55108, 55108}, + {-6208, 6208}, + {-1448, 1448}, + {-512, 511}, + {-234, 234}, + {-128, 127}, + {-78, 78}, // 10 + {-52, 52}, + {-38, 38}, + {-28, 28}, + {-22, 22}, + {-18, 18}, + {-15, 15}, + {-13, 13}, + {-11, 11}, + {-9, 9}, + {-8, 8}, // 20 + {-8, 7}, + {-7, 7}, + {-6, 6}, + {-6, 6}, + {-5, 5}, + {-5, 5}, + {-5, 5}, + {-4, 4}, + {-4, 4}, + {-4, 4}, // 30 + {-4, 4}, + {-3, 3}, + {-3, 3}, + {-3, 3}, + {-3, 3}, + {-3, 3}, + {-3, 3}, + {-3, 3}, + {-3, 3}, + {-2, 2}, // 40 + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, // 50 + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, // 60 + {-2, 2}, + {-2, 2}, + {-2, 1}}; + + return base >= kBaseLimits[exp].min && base <= kBaseLimits[exp].max; +}; +} + /* ----------------------- ExpressionPow ---------------------------- */ intrusive_ptr<Expression> ExpressionPow::create( @@ -3297,128 +3390,77 @@ Value ExpressionPow::evaluate(const Document& root) const { return Value(std::pow(baseDouble, expDouble)); } - // base and exp are both integers. - - auto representableAsLong = [](long long base, long long exp) { - // If exp is greater than 63 and base is not -1, 0, or 1, the result will overflow. - // If exp is negative and the base is not -1 or 1, the result will be fractional. - if (exp < 0 || exp > 63) { - return std::abs(base) == 1 || base == 0; + // If either number is a long, return a long. If both numbers are ints, then return an int if + // the result fits or a long if it is too big. + const auto formatResult = [baseType, expType](long long res) { + if (baseType == NumberLong || expType == NumberLong) { + return Value(res); } + return Value::createIntOrLong(res); + }; - struct MinMax { - long long min; - long long max; - }; + const long long baseLong = baseVal.getLong(); + const long long expLong = expVal.getLong(); - // Array indices correspond to exponents 0 through 63. The values in each index are the min - // and max bases, respectively, that can be raised to that exponent without overflowing a - // 64-bit int. For max bases, this was computed by solving for b in - // b = (2^63-1)^(1/exp) for exp = [0, 63] and truncating b. To calculate min bases, for even - // exps the equation used was b = (2^63-1)^(1/exp), and for odd exps the equation used was - // b = (-2^63)^(1/exp). Since the magnitude of long min is greater than long max, the - // magnitude of some of the min bases raised to odd exps is greater than the corresponding - // max bases raised to the same exponents. - - static const MinMax kBaseLimits[] = { - {std::numeric_limits<long long>::min(), std::numeric_limits<long long>::max()}, // 0 - {std::numeric_limits<long long>::min(), std::numeric_limits<long long>::max()}, - {-3037000499LL, 3037000499LL}, - {-2097152, 2097151}, - {-55108, 55108}, - {-6208, 6208}, - {-1448, 1448}, - {-512, 511}, - {-234, 234}, - {-128, 127}, - {-78, 78}, // 10 - {-52, 52}, - {-38, 38}, - {-28, 28}, - {-22, 22}, - {-18, 18}, - {-15, 15}, - {-13, 13}, - {-11, 11}, - {-9, 9}, - {-8, 8}, // 20 - {-8, 7}, - {-7, 7}, - {-6, 6}, - {-6, 6}, - {-5, 5}, - {-5, 5}, - {-5, 5}, - {-4, 4}, - {-4, 4}, - {-4, 4}, // 30 - {-4, 4}, - {-3, 3}, - {-3, 3}, - {-3, 3}, - {-3, 3}, - {-3, 3}, - {-3, 3}, - {-3, 3}, - {-3, 3}, - {-2, 2}, // 40 - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, // 50 - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, // 60 - {-2, 2}, - {-2, 2}, - {-2, 1}}; - - return base >= kBaseLimits[exp].min && base <= kBaseLimits[exp].max; + // Use this when the result cannot be represented as a long. + const auto computeDoubleResult = [baseLong, expLong]() { + return Value(std::pow(baseLong, expLong)); }; - long long baseLong = baseVal.getLong(); - long long expLong = expVal.getLong(); + // Avoid doing repeated multiplication or using std::pow if the base is -1, 0, or 1. + if (baseLong == 0) { + if (expLong == 0) { + // 0^0 = 1. + return formatResult(1); + } else if (expLong > 0) { + // 0^x where x > 0 is 0. + return formatResult(0); + } - // If the result cannot be represented as a long, return a double. Otherwise if either number is - // a long, return a long. If both numbers are ints, then return an int if the result fits or a - // long if it is too big. + // We should have checked earlier that 0 to a negative power is banned. + MONGO_UNREACHABLE; + } else if (baseLong == 1) { + return formatResult(1); + } else if (baseLong == -1) { + // -1^0 = -1^2 = -1^4 = -1^6 ... = 1 + // -1^1 = -1^3 = -1^5 = -1^7 ... = -1 + return formatResult((expLong % 2 == 0) ? 1 : -1); + } else if (expLong > 63 || expLong < 0) { + // If the base is not 0, 1, or -1 and the exponent is too large, or negative, + // the result cannot be represented as a long. + return computeDoubleResult(); + } + + // It's still possible that the result cannot be represented as a long. If that's the case, + // return a double. if (!representableAsLong(baseLong, expLong)) { - return Value(std::pow(baseLong, expLong)); + return computeDoubleResult(); } - long long result = 1; + // Use repeated multiplication, since pow() casts args to doubles which could result in + // loss of precision if arguments are very large. + const auto computeWithRepeatedMultiplication = [](long long base, long long exp) { + long long result = 1; - // When 'baseLong' == -1 and 'expLong' is < 0 the following for loop will never run because - // 'expLong' will always be less than 0 so result will always be 1. This is not always correct - // because the result can potentially be -1. ex: 'baselong' = -1 'expLong' = -5 then result - // should be -1. - if (baseLong == -1 && expLong < 0) { - expLong = expLong % 2 == 0 ? 2 : 1; - } + while (exp > 1) { + if (exp % 2 == 1) { + result *= base; + exp--; + } + // 'exp' is now guaranteed to be even. + base *= base; + exp /= 2; + } - // Use repeated multiplication, since pow() casts args to doubles which could result in loss of - // precision if arguments are very large. - for (int i = 0; i < expLong; i++) { - result *= baseLong; - } + if (exp) { + invariant(exp == 1); + result *= base; + } - if (baseType == NumberLong || expType == NumberLong) { - return Value(result); - } - return Value::createIntOrLong(result); + return result; + }; + + return formatResult(computeWithRepeatedMultiplication(baseLong, expLong)); } REGISTER_EXPRESSION(pow, ExpressionPow::parse); diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp index 94f9de8329f..2286374fa7a 100644 --- a/src/mongo/db/pipeline/expression_test.cpp +++ b/src/mongo/db/pipeline/expression_test.cpp @@ -2234,7 +2234,54 @@ TEST(ExpressionFromAccumulators, StdDevSamp) { {{}, Value(BSONNULL)}}); } -TEST(ExpressionPowTest, NegativeOneRaisedToNegativeOddExponentShouldOutPutNegativeOne) { +TEST(ExpressionPowTest, LargeExponentValuesWithBaseOfZero) { + assertExpectedResults( + "$pow", + { + {{Value(0), Value(0)}, Value(1)}, + {{Value(0LL), Value(0LL)}, Value(1LL)}, + + {{Value(0), Value(10)}, Value(0)}, + {{Value(0), Value(10000)}, Value(0)}, + + {{Value(0LL), Value(10)}, Value(0LL)}, + + // $pow may sometimes use a loop to compute a^b, so it's important to check + // that the loop doesn't hang if a large exponent is provided. + {{Value(0LL), Value(std::numeric_limits<long long>::max())}, Value(0LL)}, + }); +} + +TEST(ExpressionPowTest, ThrowsWhenBaseZeroAndExpNegative) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + VariablesParseState vps = expCtx->variablesParseState; + + const auto expr = Expression::parseExpression(expCtx, BSON("$pow" << BSON_ARRAY(0 << -5)), vps); + ASSERT_THROWS([&] { expr->evaluate(Document()); }(), AssertionException); + + const auto exprWithLong = + Expression::parseExpression(expCtx, BSON("$pow" << BSON_ARRAY(0LL << -5LL)), vps); + ASSERT_THROWS([&] { expr->evaluate(Document()); }(), AssertionException); +} + +TEST(ExpressionPowTest, LargeExponentValuesWithBaseOfOne) { + assertExpectedResults( + "$pow", + { + {{Value(1), Value(10)}, Value(1)}, + {{Value(1), Value(10LL)}, Value(1LL)}, + {{Value(1), Value(10000LL)}, Value(1LL)}, + + {{Value(1LL), Value(10LL)}, Value(1LL)}, + + // $pow may sometimes use a loop to compute a^b, so it's important to check + // that the loop doesn't hang if a large exponent is provided. + {{Value(1LL), Value(std::numeric_limits<long long>::max())}, Value(1LL)}, + {{Value(1LL), Value(std::numeric_limits<long long>::min())}, Value(1LL)}, + }); +} + +TEST(ExpressionPowTest, LargeExponentValuesWithBaseOfNegativeOne) { assertExpectedResults("$pow", { {{Value(-1), Value(-1)}, Value(-1)}, @@ -2247,6 +2294,32 @@ TEST(ExpressionPowTest, NegativeOneRaisedToNegativeOddExponentShouldOutPutNegati {{Value(-1LL), Value(-3LL)}, Value(-1LL)}, {{Value(-1LL), Value(-4LL)}, Value(1LL)}, {{Value(-1LL), Value(-5LL)}, Value(-1LL)}, + + {{Value(-1LL), Value(-61LL)}, Value(-1LL)}, + {{Value(-1LL), Value(61LL)}, Value(-1LL)}, + + {{Value(-1LL), Value(-62LL)}, Value(1LL)}, + {{Value(-1LL), Value(62LL)}, Value(1LL)}, + + {{Value(-1LL), Value(-101LL)}, Value(-1LL)}, + {{Value(-1LL), Value(-102LL)}, Value(1LL)}, + + // Use a value large enough that will make the test hang for a + // considerable amount of time if a loop is used to compute the + // answer. + {{Value(-1LL), Value(63234673905128LL)}, Value(1LL)}, + {{Value(-1LL), Value(-63234673905128LL)}, Value(1LL)}, + + {{Value(-1LL), Value(63234673905127LL)}, Value(-1LL)}, + {{Value(-1LL), Value(-63234673905127LL)}, Value(-1LL)}, + }); +} + +TEST(ExpressionPowTest, LargeBaseSmallPositiveExponent) { + assertExpectedResults("$pow", + { + {{Value(4294967296LL), Value(1LL)}, Value(4294967296LL)}, + {{Value(4294967296LL), Value(0)}, Value(1LL)}, }); } |