diff options
Diffstat (limited to 'src/mongo/db/pipeline')
-rw-r--r-- | src/mongo/db/pipeline/expression.cpp | 259 |
-rw-r--r-- | src/mongo/db/pipeline/expression_test.cpp | 91 |
2 files changed, 246 insertions, 104 deletions
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index a86eb18f2f8..3a3eb429652 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -2642,6 +2642,99 @@ const char* ExpressionOr::getOpName() const { return "$or"; } +namespace { +/** + * Helper for ExpressionPow to determine whether base^exp can be represented in a 64 bit int. + * + *'base' and 'exp' are both integers. Assumes 'exp' is in the range [0, 63]. + */ +bool representableAsLong(long long base, long long exp) { + invariant(exp <= 63); + invariant(exp >= 0); + struct MinMax { + long long min; + long long max; + }; + + // Array indices correspond to exponents 0 through 63. The values in each index are the min + // and max bases, respectively, that can be raised to that exponent without overflowing a + // 64-bit int. For max bases, this was computed by solving for b in + // b = (2^63-1)^(1/exp) for exp = [0, 63] and truncating b. To calculate min bases, for even + // exps the equation used was b = (2^63-1)^(1/exp), and for odd exps the equation used was + // b = (-2^63)^(1/exp). Since the magnitude of long min is greater than long max, the + // magnitude of some of the min bases raised to odd exps is greater than the corresponding + // max bases raised to the same exponents. 
+ + static const MinMax kBaseLimits[] = { + {std::numeric_limits<long long>::min(), std::numeric_limits<long long>::max()}, // 0 + {std::numeric_limits<long long>::min(), std::numeric_limits<long long>::max()}, + {-3037000499LL, 3037000499LL}, + {-2097152, 2097151}, + {-55108, 55108}, + {-6208, 6208}, + {-1448, 1448}, + {-512, 511}, + {-234, 234}, + {-128, 127}, + {-78, 78}, // 10 + {-52, 52}, + {-38, 38}, + {-28, 28}, + {-22, 22}, + {-18, 18}, + {-15, 15}, + {-13, 13}, + {-11, 11}, + {-9, 9}, + {-8, 8}, // 20 + {-8, 7}, + {-7, 7}, + {-6, 6}, + {-6, 6}, + {-5, 5}, + {-5, 5}, + {-5, 5}, + {-4, 4}, + {-4, 4}, + {-4, 4}, // 30 + {-4, 4}, + {-3, 3}, + {-3, 3}, + {-3, 3}, + {-3, 3}, + {-3, 3}, + {-3, 3}, + {-3, 3}, + {-3, 3}, + {-2, 2}, // 40 + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, // 50 + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, // 60 + {-2, 2}, + {-2, 2}, + {-2, 1}}; + + return base >= kBaseLimits[exp].min && base <= kBaseLimits[exp].max; +}; +} + /* ----------------------- ExpressionPow ---------------------------- */ intrusive_ptr<Expression> ExpressionPow::create( @@ -2692,119 +2785,77 @@ Value ExpressionPow::evaluateInternal(Variables* vars) const { return Value(std::pow(baseDouble, expDouble)); } - // base and exp are both integers. - - auto representableAsLong = [](long long base, long long exp) { - // If exp is greater than 63 and base is not -1, 0, or 1, the result will overflow. - // If exp is negative and the base is not -1 or 1, the result will be fractional. - if (exp < 0 || exp > 63) { - return std::abs(base) == 1 || base == 0; + // If either number is a long, return a long. If both numbers are ints, then return an int if + // the result fits or a long if it is too big. 
+ const auto formatResult = [baseType, expType](long long res) { + if (baseType == NumberLong || expType == NumberLong) { + return Value(res); } + return Value::createIntOrLong(res); + }; + + const long long baseLong = baseVal.getLong(); + const long long expLong = expVal.getLong(); - struct MinMax { - long long min; - long long max; - }; - - // Array indices correspond to exponents 0 through 63. The values in each index are the min - // and max bases, respectively, that can be raised to that exponent without overflowing a - // 64-bit int. For max bases, this was computed by solving for b in - // b = (2^63-1)^(1/exp) for exp = [0, 63] and truncating b. To calculate min bases, for even - // exps the equation used was b = (2^63-1)^(1/exp), and for odd exps the equation used was - // b = (-2^63)^(1/exp). Since the magnitude of long min is greater than long max, the - // magnitude of some of the min bases raised to odd exps is greater than the corresponding - // max bases raised to the same exponents. 
- - static const MinMax kBaseLimits[] = { - {std::numeric_limits<long long>::min(), std::numeric_limits<long long>::max()}, // 0 - {std::numeric_limits<long long>::min(), std::numeric_limits<long long>::max()}, - {-3037000499LL, 3037000499LL}, - {-2097152, 2097151}, - {-55108, 55108}, - {-6208, 6208}, - {-1448, 1448}, - {-512, 511}, - {-234, 234}, - {-128, 127}, - {-78, 78}, // 10 - {-52, 52}, - {-38, 38}, - {-28, 28}, - {-22, 22}, - {-18, 18}, - {-15, 15}, - {-13, 13}, - {-11, 11}, - {-9, 9}, - {-8, 8}, // 20 - {-8, 7}, - {-7, 7}, - {-6, 6}, - {-6, 6}, - {-5, 5}, - {-5, 5}, - {-5, 5}, - {-4, 4}, - {-4, 4}, - {-4, 4}, // 30 - {-4, 4}, - {-3, 3}, - {-3, 3}, - {-3, 3}, - {-3, 3}, - {-3, 3}, - {-3, 3}, - {-3, 3}, - {-3, 3}, - {-2, 2}, // 40 - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, // 50 - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, // 60 - {-2, 2}, - {-2, 2}, - {-2, 1}}; - - return base >= kBaseLimits[exp].min && base <= kBaseLimits[exp].max; + // Use this when the result cannot be represented as a long. + const auto computeDoubleResult = [baseLong, expLong]() { + return Value(std::pow(baseLong, expLong)); }; - long long baseLong = baseVal.getLong(); - long long expLong = expVal.getLong(); + // Avoid doing repeated multiplication or using std::pow if the base is -1, 0, or 1. + if (baseLong == 0) { + if (expLong == 0) { + // 0^0 = 1. + return formatResult(1); + } else if (expLong > 0) { + // 0^x where x > 0 is 0. + return formatResult(0); + } - // If the result cannot be represented as a long, return a double. Otherwise if either number - // is a long, return a long. If both numbers are ints, then return an int if the result fits or - // a long if it is too big. + // We should have checked earlier that 0 to a negative power is banned. 
+ MONGO_UNREACHABLE; + } else if (baseLong == 1) { + return formatResult(1); + } else if (baseLong == -1) { + // -1^0 = -1^2 = -1^4 = -1^6 ... = 1 + // -1^1 = -1^3 = -1^5 = -1^7 ... = -1 + return formatResult((expLong % 2 == 0) ? 1 : -1); + } else if (expLong > 63 || expLong < 0) { + // If the base is not 0, 1, or -1 and the exponent is too large, or negative, + // the result cannot be represented as a long. + return computeDoubleResult(); + } + + // It's still possible that the result cannot be represented as a long. If that's the case, + // return a double. if (!representableAsLong(baseLong, expLong)) { - return Value(std::pow(baseLong, expLong)); + return computeDoubleResult(); } - long long result = 1; - // Use repeated multiplication, since pow() casts args to doubles which could result in loss of - // precision if arguments are very large. - for (int i = 0; i < expLong; i++) { - result *= baseLong; - } + // Use repeated multiplication, since pow() casts args to doubles which could result in + // loss of precision if arguments are very large. + const auto computeWithRepeatedMultiplication = [](long long base, long long exp) { + long long result = 1; - if (baseType == NumberLong || expType == NumberLong) { - return Value(result); - } - return Value::createIntOrLong(result); + while (exp > 1) { + if (exp % 2 == 1) { + result *= base; + exp--; + } + // 'exp' is now guaranteed to be even. 
+ base *= base; + exp /= 2; + } + + if (exp) { + invariant(exp == 1); + result *= base; + } + + return result; + }; + + return formatResult(computeWithRepeatedMultiplication(baseLong, expLong)); } REGISTER_EXPRESSION(pow, ExpressionPow::parse); diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp index 8acf746a696..b6a6b75d966 100644 --- a/src/mongo/db/pipeline/expression_test.cpp +++ b/src/mongo/db/pipeline/expression_test.cpp @@ -2180,6 +2180,97 @@ TEST(ExpressionFromAccumulators, StdDevSamp) { {{}, Value(BSONNULL)}}); } +TEST(ExpressionPowTest, LargeExponentValuesWithBaseOfZero) { + assertExpectedResults( + "$pow", + { + {{Value(0), Value(0)}, Value(1)}, + {{Value(0LL), Value(0LL)}, Value(1LL)}, + + {{Value(0), Value(10)}, Value(0)}, + {{Value(0), Value(10000)}, Value(0)}, + + {{Value(0LL), Value(10)}, Value(0LL)}, + + // $pow may sometimes use a loop to compute a^b, so it's important to check + // that the loop doesn't hang if a large exponent is provided. 
+ {{Value(0LL), Value(std::numeric_limits<long long>::max())}, Value(0LL)}, + }); +} + +TEST(ExpressionPowTest, ThrowsWhenBaseZeroAndExpNegative) { + VariablesIdGenerator idGenerator; + VariablesParseState vps(&idGenerator); + + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + // 0 raised to a negative power must be rejected when both operands are ints. + const auto expr = Expression::parseExpression(expCtx, BSON("$pow" << BSON_ARRAY(0 << -5)), vps); + ASSERT_THROWS([&] { expr->evaluate(Document()); }(), AssertionException); + // The same must hold when both operands are longs, so evaluate the long-typed expression. + const auto exprWithLong = + Expression::parseExpression(expCtx, BSON("$pow" << BSON_ARRAY(0LL << -5LL)), vps); + ASSERT_THROWS([&] { exprWithLong->evaluate(Document()); }(), AssertionException); +} + +TEST(ExpressionPowTest, LargeExponentValuesWithBaseOfOne) { + assertExpectedResults( + "$pow", + { + {{Value(1), Value(10)}, Value(1)}, + {{Value(1), Value(10LL)}, Value(1LL)}, + {{Value(1), Value(10000LL)}, Value(1LL)}, + + {{Value(1LL), Value(10LL)}, Value(1LL)}, + + // $pow may sometimes use a loop to compute a^b, so it's important to check + // that the loop doesn't hang if a large exponent is provided. 
+ {{Value(1LL), Value(std::numeric_limits<long long>::max())}, Value(1LL)}, + {{Value(1LL), Value(std::numeric_limits<long long>::min())}, Value(1LL)}, + }); +} + +TEST(ExpressionPowTest, LargeExponentValuesWithBaseOfNegativeOne) { + assertExpectedResults("$pow", + { + {{Value(-1), Value(-1)}, Value(-1)}, + {{Value(-1), Value(-2)}, Value(1)}, + {{Value(-1), Value(-3)}, Value(-1)}, + + {{Value(-1LL), Value(0LL)}, Value(1LL)}, + {{Value(-1LL), Value(-1LL)}, Value(-1LL)}, + {{Value(-1LL), Value(-2LL)}, Value(1LL)}, + {{Value(-1LL), Value(-3LL)}, Value(-1LL)}, + {{Value(-1LL), Value(-4LL)}, Value(1LL)}, + {{Value(-1LL), Value(-5LL)}, Value(-1LL)}, + + {{Value(-1LL), Value(-61LL)}, Value(-1LL)}, + {{Value(-1LL), Value(61LL)}, Value(-1LL)}, + + {{Value(-1LL), Value(-62LL)}, Value(1LL)}, + {{Value(-1LL), Value(62LL)}, Value(1LL)}, + + {{Value(-1LL), Value(-101LL)}, Value(-1LL)}, + {{Value(-1LL), Value(-102LL)}, Value(1LL)}, + + // Use a value large enough that will make the test hang for a + // considerable amount of time if a loop is used to compute the + // answer. + {{Value(-1LL), Value(63234673905128LL)}, Value(1LL)}, + {{Value(-1LL), Value(-63234673905128LL)}, Value(1LL)}, + + {{Value(-1LL), Value(63234673905127LL)}, Value(-1LL)}, + {{Value(-1LL), Value(-63234673905127LL)}, Value(-1LL)}, + }); +} + +TEST(ExpressionPowTest, LargeBaseSmallPositiveExponent) { + assertExpectedResults("$pow", + { + {{Value(4294967296LL), Value(1LL)}, Value(4294967296LL)}, + {{Value(4294967296LL), Value(0)}, Value(1LL)}, + }); +} + namespace FieldPath { /** The provided field path does not pass validation. */ |