diff options
Diffstat (limited to 'src/mongo/db/pipeline/expression.cpp')
-rw-r--r-- | src/mongo/db/pipeline/expression.cpp | 252 |
1 files changed, 193 insertions, 59 deletions
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index 9c68924d476..1c976df8f42 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -41,8 +41,10 @@ #include "mongo/db/pipeline/expression_context.h" #include "mongo/db/pipeline/value.h" #include "mongo/platform/bits.h" +#include "mongo/platform/decimal128.h" #include "mongo/util/mongoutils/str.h" #include "mongo/util/string_map.h" +#include "mongo/util/summation.h" namespace mongo { using Parser = Expression::Parser; @@ -414,6 +416,8 @@ Value ExpressionAbs::evaluateNumericArg(const Value& numericArg) const { BSONType type = numericArg.getType(); if (type == NumberDouble) { return Value(std::abs(numericArg.getDouble())); + } else if (type == NumberDecimal) { + return Value(numericArg.getDecimal().toAbs()); } else { long long num = numericArg.getLong(); uassert(28680, @@ -432,14 +436,12 @@ const char* ExpressionAbs::getOpName() const { /* ------------------------- ExpressionAdd ----------------------------- */ Value ExpressionAdd::evaluateInternal(Variables* vars) const { - /* - We'll try to return the narrowest possible result value. To do that - without creating intermediate Values, do the arithmetic for double - and integral types in parallel, tracking the current narrowest - type. - */ - double doubleTotal = 0; - long long longTotal = 0; + // We'll try to return the narrowest possible result value while avoiding overflow, loss + // of precision due to intermediate rounding or implicit use of decimal types. To do that, + // compute a compensated sum for non-decimal values and a separate decimal sum for decimal + // values, and track the current narrowest type. + DoubleDoubleSummation nonDecimalTotal; + Decimal128 decimalTotal; BSONType totalType = NumberInt; bool haveDate = false; @@ -447,40 +449,64 @@ Value ExpressionAdd::evaluateInternal(Variables* vars) const { for (size_t i = 0; i < n; ++i) { Value val = vpOperand[i]->evaluateInternal(vars); - if (val.numeric()) { - totalType = Value::getWidestNumeric(totalType, val.getType()); - - doubleTotal += val.coerceToDouble(); - longTotal += val.coerceToLong(); - } else if (val.getType() == Date) { - uassert(16612, "only one date allowed in an $add expression", !haveDate); - haveDate = true; - - // We don't manipulate totalType here. - - longTotal += val.getDate(); - doubleTotal += val.getDate(); - } else if (val.nullish()) { - return Value(BSONNULL); - } else { - uasserted(16554, - str::stream() << "$add only supports numeric or date types, not " - << typeName(val.getType())); + switch (val.getType()) { + case NumberDecimal: + decimalTotal = decimalTotal.add(val.getDecimal()); + totalType = NumberDecimal; + break; + case NumberDouble: + nonDecimalTotal.addDouble(val.getDouble()); + if (totalType != NumberDecimal) + totalType = NumberDouble; + break; + case NumberLong: + nonDecimalTotal.addLong(val.getLong()); + if (totalType == NumberInt) + totalType = NumberLong; + break; + case NumberInt: + nonDecimalTotal.addDouble(val.getInt()); + break; + case Date: + uassert(16612, "only one date allowed in an $add expression", !haveDate); + haveDate = true; + nonDecimalTotal.addLong(val.getDate()); + break; + default: + uassert(16554, + str::stream() << "$add only supports numeric or date types, not " + << typeName(val.getType()), + val.nullish()); + return Value(BSONNULL); } } if (haveDate) { - if (totalType == NumberDouble) - longTotal = static_cast<long long>(doubleTotal); + int64_t longTotal; + if (totalType == NumberDecimal) { + longTotal = decimalTotal.add(nonDecimalTotal.getDecimal()).toLong(); + } else { + uassert(ErrorCodes::Overflow, "date overflow in $add", nonDecimalTotal.fitsLong()); + longTotal = nonDecimalTotal.getLong(); + } return Value(Date_t::fromMillisSinceEpoch(longTotal)); - } else if (totalType == NumberLong) { - return Value(longTotal); - } else if (totalType == NumberDouble) { - return Value(doubleTotal); - } else if (totalType == NumberInt) { - return Value::createIntOrLong(longTotal); - } else { - massert(16417, "$add resulted in a non-numeric type", false); + } + switch (totalType) { + case NumberDecimal: + return Value(decimalTotal.add(nonDecimalTotal.getDecimal())); + case NumberLong: + dassert(nonDecimalTotal.isInteger()); + if (nonDecimalTotal.fitsLong()) + return Value(nonDecimalTotal.getLong()); + // Fallthrough. + case NumberInt: + if (nonDecimalTotal.fitsLong()) + return Value::createIntOrLong(nonDecimalTotal.getLong()); + // Fallthrough. + case NumberDouble: + return Value(nonDecimalTotal.getDouble()); + default: + massert(16417, "$add resulted in a non-numeric type", false); } } @@ -677,8 +703,16 @@ const char* ExpressionArrayElemAt::getOpName() const { Value ExpressionCeil::evaluateNumericArg(const Value& numericArg) const { // There's no point in taking the ceiling of integers or longs, it will have no effect. - return numericArg.getType() == NumberDouble ? Value(std::ceil(numericArg.getDouble())) - : numericArg; + switch (numericArg.getType()) { + case NumberDouble: + return Value(std::ceil(numericArg.getDouble())); + case NumberDecimal: + // Round toward the nearest decimal with a zero exponent in the positive direction. + return Value(numericArg.getDecimal().quantize(Decimal128::kNormalizedZero, + Decimal128::kRoundTowardPositive)); + default: + return numericArg; + } } REGISTER_EXPRESSION(ceil, ExpressionCeil::parse); @@ -1203,10 +1237,20 @@ Value ExpressionDivide::evaluateInternal(Variables* vars) const { Value lhs = vpOperand[0]->evaluateInternal(vars); Value rhs = vpOperand[1]->evaluateInternal(vars); + auto assertNonZero = [](bool nonZero) { uassert(16608, "can't $divide by zero", nonZero); }; + if (lhs.numeric() && rhs.numeric()) { + // If, and only if, either side is decimal, return decimal. + if (lhs.getType() == NumberDecimal || rhs.getType() == NumberDecimal) { + Decimal128 numer = lhs.coerceToDecimal(); + Decimal128 denom = rhs.coerceToDecimal(); + assertNonZero(!denom.isZero()); + return Value(numer.divide(denom)); + } + double numer = lhs.coerceToDouble(); double denom = rhs.coerceToDouble(); - uassert(16608, "can't $divide by zero", denom != 0); + assertNonZero(denom != 0.0); return Value(numer / denom); } else if (lhs.nullish() || rhs.nullish()) { @@ -1228,7 +1272,10 @@ const char* ExpressionDivide::getOpName() const { /* ----------------------- ExpressionExp ---------------------------- */ Value ExpressionExp::evaluateNumericArg(const Value& numericArg) const { - // exp() always returns a double since e is a double. + // $exp always returns either a double or a decimal number, as e is irrational. + if (numericArg.getType() == NumberDecimal) + return Value(numericArg.coerceToDecimal().exponential()); + return Value(exp(numericArg.coerceToDouble())); } @@ -1740,8 +1787,16 @@ void ExpressionFilter::addDependencies(DepsTracker* deps, vector<string>* path) Value ExpressionFloor::evaluateNumericArg(const Value& numericArg) const { // There's no point in taking the floor of integers or longs, it will have no effect. - return numericArg.getType() == NumberDouble ? Value(std::floor(numericArg.getDouble())) - : numericArg; + switch (numericArg.getType()) { + case NumberDouble: + return Value(std::floor(numericArg.getDouble())); + case NumberDecimal: + // Round toward the nearest decimal with a zero exponent in the negative direction. + return Value(numericArg.getDecimal().quantize(Decimal128::kNormalizedZero, + Decimal128::kRoundTowardNegative)); + default: + return numericArg; + } } REGISTER_EXPRESSION(floor, ExpressionFloor::parse); @@ -2028,10 +2083,20 @@ Value ExpressionMod::evaluateInternal(Variables* vars) const { BSONType rightType = rhs.getType(); if (lhs.numeric() && rhs.numeric()) { + auto assertNonZero = [](bool isZero) { uassert(16610, "can't $mod by zero", !isZero); }; + + // If either side is decimal, perform the operation in decimal. + if (leftType == NumberDecimal || rightType == NumberDecimal) { + Decimal128 left = lhs.coerceToDecimal(); + Decimal128 right = rhs.coerceToDecimal(); + assertNonZero(right.isZero()); + + return Value(left.modulo(right)); + } + // ensure we aren't modding by 0 double right = rhs.coerceToDouble(); - - uassert(16610, "can't $mod by 0", right != 0); + assertNonZero(right == 0); if (leftType == NumberDouble || (rightType == NumberDouble && !rhs.integral())) { // Need to do fmod. Integer-valued double case is handled below. @@ -2088,6 +2153,8 @@ Value ExpressionMultiply::evaluateInternal(Variables* vars) const { */ double doubleProduct = 1; long long longProduct = 1; + Decimal128 decimalProduct; // This will be initialized on encountering the first decimal. + BSONType productType = NumberInt; const size_t n = vpOperand.size(); @@ -2095,10 +2162,23 @@ Value ExpressionMultiply::evaluateInternal(Variables* vars) const { Value val = vpOperand[i]->evaluateInternal(vars); if (val.numeric()) { + BSONType oldProductType = productType; productType = Value::getWidestNumeric(productType, val.getType()); - - doubleProduct *= val.coerceToDouble(); - longProduct *= val.coerceToLong(); + if (productType == NumberDecimal) { + // On finding the first decimal, convert the partial product to decimal. + if (oldProductType != NumberDecimal) { + decimalProduct = oldProductType == NumberDouble + ? Decimal128(doubleProduct, Decimal128::kRoundTo15Digits) + : Decimal128(static_cast<int64_t>(longProduct)); + } + decimalProduct = decimalProduct.multiply(val.coerceToDecimal()); + } else { + doubleProduct *= val.coerceToDouble(); + if (mongoSignedMultiplyOverflow64(longProduct, val.coerceToLong(), &longProduct)) { + // The 'longProduct' would have overflowed, so we're abandoning it. + productType = NumberDouble; + } + } } else if (val.nullish()) { return Value(BSONNULL); } else { @@ -2114,6 +2194,8 @@ Value ExpressionMultiply::evaluateInternal(Variables* vars) const { return Value(longProduct); else if (productType == NumberInt) return Value::createIntOrLong(longProduct); + else if (productType == NumberDecimal) + return Value(decimalProduct); else massert(16418, "$multiply resulted in a non-numeric type", false); } @@ -2390,6 +2472,12 @@ const char* ExpressionIndexOfCP::getOpName() const { /* ----------------------- ExpressionLn ---------------------------- */ Value ExpressionLn::evaluateNumericArg(const Value& numericArg) const { + if (numericArg.getType() == NumberDecimal) { + Decimal128 argDecimal = numericArg.getDecimal(); + if (argDecimal.isGreater(Decimal128::kNormalizedZero)) + return Value(argDecimal.logarithm()); + // Fall through for error case. + } double argDouble = numericArg.coerceToDouble(); uassert(28766, str::stream() << "$ln's argument must be a positive number, but is " << argDouble, @@ -2417,6 +2505,18 @@ Value ExpressionLog::evaluateInternal(Variables* vars) const { str::stream() << "$log's base must be numeric, not " << typeName(baseVal.getType()), baseVal.numeric()); + if (argVal.getType() == NumberDecimal || baseVal.getType() == NumberDecimal) { + Decimal128 argDecimal = argVal.coerceToDecimal(); + Decimal128 baseDecimal = baseVal.coerceToDecimal(); + + if (argDecimal.isGreater(Decimal128::kNormalizedZero) && + baseDecimal.isNotEqual(Decimal128(1)) && + baseDecimal.isGreater(Decimal128::kNormalizedZero)) { + return Value(argDecimal.logarithm(baseDecimal)); + } + // Fall through for error cases. + } + double argDouble = argVal.coerceToDouble(); double baseDouble = baseVal.coerceToDouble(); uassert(28758, @@ -2437,6 +2537,13 @@ const char* ExpressionLog::getOpName() const { /* ----------------------- ExpressionLog10 ---------------------------- */ Value ExpressionLog10::evaluateNumericArg(const Value& numericArg) const { + if (numericArg.getType() == NumberDecimal) { + Decimal128 argDecimal = numericArg.getDecimal(); + if (argDecimal.isGreater(Decimal128::kNormalizedZero)) + return Value(argDecimal.logarithm(Decimal128(10))); + // Fall through for error case. + } + double argDouble = numericArg.coerceToDouble(); uassert(28761, str::stream() << "$log10's argument must be a positive number, but is " << argDouble, @@ -2672,15 +2779,24 @@ Value ExpressionPow::evaluateInternal(Variables* vars) const { str::stream() << "$pow's exponent must be numeric, not " << typeName(expType), expVal.numeric()); + auto checkNonZeroAndNeg = [](bool isZeroAndNeg) { + uassert(28764, "$pow cannot take a base of 0 and a negative exponent", !isZeroAndNeg); + }; + + // If either argument is decimal, return a decimal. + if (baseType == NumberDecimal || expType == NumberDecimal) { + Decimal128 baseDecimal = baseVal.coerceToDecimal(); + Decimal128 expDecimal = expVal.coerceToDecimal(); + checkNonZeroAndNeg(baseDecimal.isZero() && expDecimal.isNegative()); + return Value(baseDecimal.power(expDecimal)); + } + // pow() will cast args to doubles. double baseDouble = baseVal.coerceToDouble(); double expDouble = expVal.coerceToDouble(); + checkNonZeroAndNeg(baseDouble == 0 && expDouble < 0); - uassert(28764, - "$pow cannot take a base of 0 and a negative exponent", - !(baseDouble == 0 && expDouble < 0)); - - // If either number is a double, return a double. + // If either argument is a double, return a double. if (baseType == NumberDouble || expType == NumberDouble) { return Value(std::pow(baseDouble, expDouble)); } @@ -3414,10 +3530,17 @@ const char* ExpressionSplit::getOpName() const { /* ----------------------- ExpressionSqrt ---------------------------- */ Value ExpressionSqrt::evaluateNumericArg(const Value& numericArg) const { + auto checkArg = [](bool nonNegative) { + uassert(28714, "$sqrt's argument must be greater than or equal to 0", nonNegative); + }; + + if (numericArg.getType() == NumberDecimal) { + Decimal128 argDec = numericArg.getDecimal(); + checkArg(!argDec.isLess(Decimal128::kNormalizedZero)); // NaN returns Nan without error + return Value(argDec.squareRoot()); + } double argDouble = numericArg.coerceToDouble(); - uassert(28714, - "$sqrt's argument must be greater than or equal to 0", - argDouble >= 0 || std::isnan(argDouble)); + checkArg(!(argDouble < 0)); // NaN returns Nan without error return Value(sqrt(argDouble)); } @@ -3638,7 +3761,11 @@ Value ExpressionSubtract::evaluateInternal(Variables* vars) const { BSONType diffType = Value::getWidestNumeric(rhs.getType(), lhs.getType()); - if (diffType == NumberDouble) { + if (diffType == NumberDecimal) { + Decimal128 right = rhs.coerceToDecimal(); + Decimal128 left = lhs.coerceToDecimal(); + return Value(left.subtract(right)); + } else if (diffType == NumberDouble) { double right = rhs.coerceToDouble(); double left = lhs.coerceToDouble(); return Value(left - right); @@ -3838,8 +3965,15 @@ const char* ExpressionToUpper::getOpName() const { Value ExpressionTrunc::evaluateNumericArg(const Value& numericArg) const { // There's no point in truncating integers or longs, it will have no effect. - return numericArg.getType() == NumberDouble ? Value(std::trunc(numericArg.getDouble())) - : numericArg; + switch (numericArg.getType()) { + case NumberDecimal: + return Value(numericArg.getDecimal().quantize(Decimal128::kNormalizedZero, + Decimal128::kRoundTowardZero)); + case NumberDouble: + return Value(std::trunc(numericArg.getDouble())); + default: + return numericArg; + } } REGISTER_EXPRESSION(trunc, ExpressionTrunc::parse); |