diff options
author | Ian Boros <ian.boros@10gen.com> | 2019-01-04 12:39:59 -0500 |
---|---|---|
committer | Ian Boros <ian.boros@10gen.com> | 2019-01-10 16:14:54 -0500 |
commit | a2d97db8fe449d15eb8e275bbf318491781472bf (patch) | |
tree | 38be4b7c20157aa9a26c1eef113268b61e85ca65 /src/mongo/db/pipeline/expression.cpp | |
parent | ee1e46cee281560bf13529c6db75cfb317703780 (diff) | |
download | mongo-r3.4.19-rc0.tar.gz |
SERVER-38070 fix infinite loop in agg expressionr3.4.19-rc0r3.4.19
Diffstat (limited to 'src/mongo/db/pipeline/expression.cpp')
-rw-r--r-- | src/mongo/db/pipeline/expression.cpp | 259 |
1 files changed, 155 insertions, 104 deletions
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index a86eb18f2f8..3a3eb429652 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -2642,6 +2642,99 @@ const char* ExpressionOr::getOpName() const { return "$or"; } +namespace { +/** + * Helper for ExpressionPow to determine whether base^exp can be represented in a 64 bit int. + * + *'base' and 'exp' are both integers. Assumes 'exp' is in the range [0, 63]. + */ +bool representableAsLong(long long base, long long exp) { + invariant(exp <= 63); + invariant(exp >= 0); + struct MinMax { + long long min; + long long max; + }; + + // Array indices correspond to exponents 0 through 63. The values in each index are the min + // and max bases, respectively, that can be raised to that exponent without overflowing a + // 64-bit int. For max bases, this was computed by solving for b in + // b = (2^63-1)^(1/exp) for exp = [0, 63] and truncating b. To calculate min bases, for even + // exps the equation used was b = (2^63-1)^(1/exp), and for odd exps the equation used was + // b = (-2^63)^(1/exp). Since the magnitude of long min is greater than long max, the + // magnitude of some of the min bases raised to odd exps is greater than the corresponding + // max bases raised to the same exponents. + + static const MinMax kBaseLimits[] = { + {std::numeric_limits<long long>::min(), std::numeric_limits<long long>::max()}, // 0 + {std::numeric_limits<long long>::min(), std::numeric_limits<long long>::max()}, + {-3037000499LL, 3037000499LL}, + {-2097152, 2097151}, + {-55108, 55108}, + {-6208, 6208}, + {-1448, 1448}, + {-512, 511}, + {-234, 234}, + {-128, 127}, + {-78, 78}, // 10 + {-52, 52}, + {-38, 38}, + {-28, 28}, + {-22, 22}, + {-18, 18}, + {-15, 15}, + {-13, 13}, + {-11, 11}, + {-9, 9}, + {-8, 8}, // 20 + {-8, 7}, + {-7, 7}, + {-6, 6}, + {-6, 6}, + {-5, 5}, + {-5, 5}, + {-5, 5}, + {-4, 4}, + {-4, 4}, + {-4, 4}, // 30 + {-4, 4}, + {-3, 3}, + {-3, 3}, + {-3, 3}, + {-3, 3}, + {-3, 3}, + {-3, 3}, + {-3, 3}, + {-3, 3}, + {-2, 2}, // 40 + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, // 50 + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, + {-2, 2}, // 60 + {-2, 2}, + {-2, 2}, + {-2, 1}}; + + return base >= kBaseLimits[exp].min && base <= kBaseLimits[exp].max; +}; +} + /* ----------------------- ExpressionPow ---------------------------- */ intrusive_ptr<Expression> ExpressionPow::create( @@ -2692,119 +2785,77 @@ Value ExpressionPow::evaluateInternal(Variables* vars) const { return Value(std::pow(baseDouble, expDouble)); } - // base and exp are both integers. - - auto representableAsLong = [](long long base, long long exp) { - // If exp is greater than 63 and base is not -1, 0, or 1, the result will overflow. - // If exp is negative and the base is not -1 or 1, the result will be fractional. - if (exp < 0 || exp > 63) { - return std::abs(base) == 1 || base == 0; + // If either number is a long, return a long. If both numbers are ints, then return an int if + // the result fits or a long if it is too big. + const auto formatResult = [baseType, expType](long long res) { + if (baseType == NumberLong || expType == NumberLong) { + return Value(res); } + return Value::createIntOrLong(res); + }; + + const long long baseLong = baseVal.getLong(); + const long long expLong = expVal.getLong(); - struct MinMax { - long long min; - long long max; - }; - - // Array indices correspond to exponents 0 through 63. The values in each index are the min - // and max bases, respectively, that can be raised to that exponent without overflowing a - // 64-bit int. For max bases, this was computed by solving for b in - // b = (2^63-1)^(1/exp) for exp = [0, 63] and truncating b. To calculate min bases, for even - // exps the equation used was b = (2^63-1)^(1/exp), and for odd exps the equation used was - // b = (-2^63)^(1/exp). Since the magnitude of long min is greater than long max, the - // magnitude of some of the min bases raised to odd exps is greater than the corresponding - // max bases raised to the same exponents. - - static const MinMax kBaseLimits[] = { - {std::numeric_limits<long long>::min(), std::numeric_limits<long long>::max()}, // 0 - {std::numeric_limits<long long>::min(), std::numeric_limits<long long>::max()}, - {-3037000499LL, 3037000499LL}, - {-2097152, 2097151}, - {-55108, 55108}, - {-6208, 6208}, - {-1448, 1448}, - {-512, 511}, - {-234, 234}, - {-128, 127}, - {-78, 78}, // 10 - {-52, 52}, - {-38, 38}, - {-28, 28}, - {-22, 22}, - {-18, 18}, - {-15, 15}, - {-13, 13}, - {-11, 11}, - {-9, 9}, - {-8, 8}, // 20 - {-8, 7}, - {-7, 7}, - {-6, 6}, - {-6, 6}, - {-5, 5}, - {-5, 5}, - {-5, 5}, - {-4, 4}, - {-4, 4}, - {-4, 4}, // 30 - {-4, 4}, - {-3, 3}, - {-3, 3}, - {-3, 3}, - {-3, 3}, - {-3, 3}, - {-3, 3}, - {-3, 3}, - {-3, 3}, - {-2, 2}, // 40 - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, // 50 - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, - {-2, 2}, // 60 - {-2, 2}, - {-2, 2}, - {-2, 1}}; - - return base >= kBaseLimits[exp].min && base <= kBaseLimits[exp].max; + // Use this when the result cannot be represented as a long. + const auto computeDoubleResult = [baseLong, expLong]() { + return Value(std::pow(baseLong, expLong)); }; - long long baseLong = baseVal.getLong(); - long long expLong = expVal.getLong(); + // Avoid doing repeated multiplication or using std::pow if the base is -1, 0, or 1. + if (baseLong == 0) { + if (expLong == 0) { + // 0^0 = 1. + return formatResult(1); + } else if (expLong > 0) { + // 0^x where x > 0 is 0. + return formatResult(0); + } - // If the result cannot be represented as a long, return a double. Otherwise if either number - // is a long, return a long. If both numbers are ints, then return an int if the result fits or - // a long if it is too big. + // We should have checked earlier that 0 to a negative power is banned. + MONGO_UNREACHABLE; + } else if (baseLong == 1) { + return formatResult(1); + } else if (baseLong == -1) { + // -1^0 = -1^2 = -1^4 = -1^6 ... = 1 + // -1^1 = -1^3 = -1^5 = -1^7 ... = -1 + return formatResult((expLong % 2 == 0) ? 1 : -1); + } else if (expLong > 63 || expLong < 0) { + // If the base is not 0, 1, or -1 and the exponent is too large, or negative, + // the result cannot be represented as a long. + return computeDoubleResult(); + } + + // It's still possible that the result cannot be represented as a long. If that's the case, + // return a double. if (!representableAsLong(baseLong, expLong)) { - return Value(std::pow(baseLong, expLong)); + return computeDoubleResult(); } - long long result = 1; - // Use repeated multiplication, since pow() casts args to doubles which could result in loss of - // precision if arguments are very large. - for (int i = 0; i < expLong; i++) { - result *= baseLong; - } + // Use repeated multiplication, since pow() casts args to doubles which could result in + // loss of precision if arguments are very large. + const auto computeWithRepeatedMultiplication = [](long long base, long long exp) { + long long result = 1; - if (baseType == NumberLong || expType == NumberLong) { - return Value(result); - } - return Value::createIntOrLong(result); + while (exp > 1) { + if (exp % 2 == 1) { + result *= base; + exp--; + } + // 'exp' is now guaranteed to be even. + base *= base; + exp /= 2; + } + + if (exp) { + invariant(exp == 1); + result *= base; + } + + return result; + }; + + return formatResult(computeWithRepeatedMultiplication(baseLong, expLong)); } REGISTER_EXPRESSION(pow, ExpressionPow::parse); |