diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/mongo/db/pipeline/SConscript | 1 | ||||
-rw-r--r-- | src/mongo/db/pipeline/accumulator.h | 17 | ||||
-rw-r--r-- | src/mongo/db/pipeline/accumulator_avg.cpp | 80 | ||||
-rw-r--r-- | src/mongo/db/pipeline/accumulator_sum.cpp | 105 | ||||
-rw-r--r-- | src/mongo/db/pipeline/accumulator_test.cpp | 89 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression.cpp | 252 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression_test.cpp | 59 | ||||
-rw-r--r-- | src/mongo/db/pipeline/value.cpp | 9 | ||||
-rw-r--r-- | src/mongo/db/pipeline/value.h | 4 |
9 files changed, 455 insertions, 161 deletions
diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript index eef0b7ecada..8525401fb72 100644 --- a/src/mongo/db/pipeline/SConscript +++ b/src/mongo/db/pipeline/SConscript @@ -91,6 +91,7 @@ env.Library( ], LIBDEPS=[ 'document_value', + '$BUILD_DIR/mongo/util/summation', 'expression', 'field_path', ] diff --git a/src/mongo/db/pipeline/accumulator.h b/src/mongo/db/pipeline/accumulator.h index 7dd604b82b4..eb09459f942 100644 --- a/src/mongo/db/pipeline/accumulator.h +++ b/src/mongo/db/pipeline/accumulator.h @@ -38,6 +38,7 @@ #include "mongo/bson/bsontypes.h" #include "mongo/db/pipeline/value.h" #include "mongo/stdx/functional.h" +#include "mongo/util/summation.h" namespace mongo { /** @@ -194,9 +195,9 @@ public: } private: - BSONType totalType; - long long longTotal; - double doubleTotal; + BSONType totalType = NumberInt; + DoubleDoubleSummation nonDecimalTotal; + Decimal128 decimalTotal; }; @@ -268,7 +269,15 @@ public: static boost::intrusive_ptr<Accumulator> create(); private: - double _total; + /** + * The total of all values is partitioned between those that are decimals, and those that are + * not decimals, so the decimal total needs to add the non-decimal. + */ + Decimal128 _getDecimalTotal() const; + + bool _isDecimal; + DoubleDoubleSummation _nonDecimalTotal; + Decimal128 _decimalTotal; long long _count; }; diff --git a/src/mongo/db/pipeline/accumulator_avg.cpp b/src/mongo/db/pipeline/accumulator_avg.cpp index ed11d81ecc0..f9faf8d359a 100644 --- a/src/mongo/db/pipeline/accumulator_avg.cpp +++ b/src/mongo/db/pipeline/accumulator_avg.cpp @@ -33,6 +33,7 @@ #include "mongo/db/pipeline/expression.h" #include "mongo/db/pipeline/expression_context.h" #include "mongo/db/pipeline/value.h" +#include "mongo/platform/decimal128.h" namespace mongo { @@ -47,48 +48,85 @@ const char* AccumulatorAvg::getOpName() const { namespace { const char subTotalName[] = "subTotal"; +const char subTotalErrorName[] = "subTotalError"; // Used for extra precision const char countName[] = "count"; -} +} // namespace void AccumulatorAvg::processInternal(const Value& input, bool merging) { - if (!merging) { - // non numeric types have no impact on average - if (!input.numeric()) - return; - - _total += input.getDouble(); - _count += 1; - } else { - // We expect an object that contains both a subtotal and a count. - // This is what getValue(true) produced below. + if (merging) { + // We expect an object that contains both a subtotal and a count. Additionally there may + // be an error value, that allows for additional precision. + // 'input' is what getValue(true) produced below. verify(input.getType() == Object); - _total += input[subTotalName].getDouble(); - _count += input[countName].getLong(); + // We're recursively adding the subtotal to get the proper type treatment, but this only + // increments the count by one, so adjust the count afterwards. Similarly for 'error'. + processInternal(input[subTotalName], false); + _count += input[countName].getLong() - 1; + Value error = input[subTotalErrorName]; + if (!error.missing()) { + processInternal(error, false); + _count--; // The error correction only adjusts the total, not the number of items. + } + return; + } + + switch (input.getType()) { + case NumberDecimal: + _decimalTotal = _decimalTotal.add(input.getDecimal()); + _isDecimal = true; + break; + case NumberLong: + // Avoid summation using double as that loses precision. + _nonDecimalTotal.addLong(input.getLong()); + break; + case NumberInt: + case NumberDouble: + _nonDecimalTotal.addDouble(input.getDouble()); + break; + default: + dassert(!input.numeric()); + return; } + _count++; } intrusive_ptr<Accumulator> AccumulatorAvg::create() { return new AccumulatorAvg(); } +Decimal128 AccumulatorAvg::_getDecimalTotal() const { + return _decimalTotal.add(_nonDecimalTotal.getDecimal()); +} + Value AccumulatorAvg::getValue(bool toBeMerged) const { - if (!toBeMerged) { - if (_count == 0) - return Value(BSONNULL); + if (toBeMerged) { + if (_isDecimal) + return Value(Document{{subTotalName, _getDecimalTotal()}, {countName, _count}}); - return Value(_total / static_cast<double>(_count)); - } else { - return Value(DOC(subTotalName << _total << countName << _count)); + double total, error; + std::tie(total, error) = _nonDecimalTotal.getDoubleDouble(); + return Value( + Document{{subTotalName, total}, {countName, _count}, {subTotalErrorName, error}}); } + + if (_count == 0) + return Value(BSONNULL); + + if (_isDecimal) + return Value(_getDecimalTotal().divide(Decimal128(static_cast<int64_t>(_count)))); + + return Value(_nonDecimalTotal.getDouble() / static_cast<double>(_count)); } -AccumulatorAvg::AccumulatorAvg() : _total(0), _count(0) { +AccumulatorAvg::AccumulatorAvg() : _isDecimal(false), _count(0) { // This is a fixed size Accumulator so we never need to update this _memUsageBytes = sizeof(*this); } void AccumulatorAvg::reset() { - _total = 0; + _isDecimal = false; + _nonDecimalTotal = {}; + _decimalTotal = {}; _count = 0; } } diff --git a/src/mongo/db/pipeline/accumulator_sum.cpp b/src/mongo/db/pipeline/accumulator_sum.cpp index c064fe52f04..1255c3fde8b 100644 --- a/src/mongo/db/pipeline/accumulator_sum.cpp +++ b/src/mongo/db/pipeline/accumulator_sum.cpp @@ -28,9 +28,13 @@ #include "mongo/platform/basic.h" +#include <cmath> +#include <limits> + #include "mongo/db/pipeline/accumulator.h" #include "mongo/db/pipeline/expression.h" #include "mongo/db/pipeline/value.h" +#include "mongo/util/summation.h" namespace mongo { @@ -43,24 +47,38 @@ const char* AccumulatorSum::getOpName() const { return "$sum"; } +namespace { +const char subTotalName[] = "subTotal"; +const char subTotalErrorName[] = "subTotalError"; // Used for extra precision. +} // namespace + + void AccumulatorSum::processInternal(const Value& input, bool merging) { - // do nothing with non numeric types - if (!input.numeric()) + if (!input.numeric()) { + if (merging && input.getType() == Object) { + // Process merge document, see getValue() below. + nonDecimalTotal.addDouble( + input[subTotalName].getDouble()); // Sum without adjusting type. + processInternal(input[subTotalErrorName], false); // Sum adjusting for type of error. + } return; + } - // upgrade to the widest type required to hold the result + // Upgrade to the widest type required to hold the result. totalType = Value::getWidestNumeric(totalType, input.getType()); - - if (totalType == NumberInt || totalType == NumberLong) { - long long v = input.coerceToLong(); - longTotal += v; - doubleTotal += v; - } else if (totalType == NumberDouble) { - double v = input.coerceToDouble(); - doubleTotal += v; - } else { - // non numerics should have returned above so we should never get here - verify(false); + switch (input.getType()) { + case NumberInt: + case NumberLong: + nonDecimalTotal.addLong(input.coerceToLong()); + break; + case NumberDouble: + nonDecimalTotal.addDouble(input.getDouble()); + break; + case NumberDecimal: + decimalTotal = decimalTotal.add(input.coerceToDecimal()); + break; + default: + MONGO_UNREACHABLE; } } @@ -69,25 +87,58 @@ intrusive_ptr<Accumulator> AccumulatorSum::create() { } Value AccumulatorSum::getValue(bool toBeMerged) const { - if (totalType == NumberLong) { - return Value(longTotal); - } else if (totalType == NumberDouble) { - return Value(doubleTotal); - } else if (totalType == NumberInt) { - return Value::createIntOrLong(longTotal); - } else { - massert(16000, "$sum resulted in a non-numeric type", false); + switch (totalType) { + case NumberInt: + if (nonDecimalTotal.fitsLong()) + return Value::createIntOrLong(nonDecimalTotal.getLong()); + // Fallthrough. + case NumberLong: + if (nonDecimalTotal.fitsLong()) + return Value(nonDecimalTotal.getLong()); + if (toBeMerged) { + // The value was too large for a NumberLong, so output a document with two values + // adding up to the desired total. Older MongoDB versions used to ignore signed + // integer overflow and cause undefined behavior, that in practice resulted in + // values that would wrap around modulo 2**64. Now an older mongos with a newer + // mongod will yield an error that $sum resulted in a non-numeric type, which is + // OK for this case. Output the error using the totalType, so in the future we can + // determine the correct totalType for the sum. For the error to exceed 2**63, + // more than 2**53 integers would have to be summed, which is impossible. + double total; + double error; + std::tie(total, error) = nonDecimalTotal.getDoubleDouble(); + long long llerror = static_cast<long long>(error); + return Value(DOC(subTotalName << total << subTotalErrorName << llerror)); + } + // Sum doesn't fit a NumberLong, so return a NumberDouble instead. + return Value(nonDecimalTotal.getDouble()); + + case NumberDouble: + return Value(nonDecimalTotal.getDouble()); + case NumberDecimal: { + double sum, error; + std::tie(sum, error) = nonDecimalTotal.getDoubleDouble(); + Decimal128 total; // zero + if (sum != 0) { + total = total.add(Decimal128(sum, Decimal128::kRoundTo34Digits)); + total = total.add(Decimal128(error, Decimal128::kRoundTo34Digits)); + } + total = total.add(decimalTotal); + return Value(total); + } + default: + MONGO_UNREACHABLE; } } -AccumulatorSum::AccumulatorSum() : totalType(NumberInt), longTotal(0), doubleTotal(0) { - // This is a fixed size Accumulator so we never need to update this +AccumulatorSum::AccumulatorSum() { + // This is a fixed size Accumulator so we never need to update this. _memUsageBytes = sizeof(*this); } void AccumulatorSum::reset() { totalType = NumberInt; - longTotal = 0; - doubleTotal = 0; -} + nonDecimalTotal = {}; + decimalTotal = {}; } +} // namespace mongo diff --git a/src/mongo/db/pipeline/accumulator_test.cpp b/src/mongo/db/pipeline/accumulator_test.cpp index 928eb2e868e..e354d71d8f9 100644 --- a/src/mongo/db/pipeline/accumulator_test.cpp +++ b/src/mongo/db/pipeline/accumulator_test.cpp @@ -96,38 +96,55 @@ static void assertExpectedResults( TEST(Accumulators, Avg) { assertExpectedResults( "$avg", - {// No documents evaluated. - {{}, Value(BSONNULL)}, - - // One int value is converted to double. - {{Value(3)}, Value(3.0)}, - // One long value is converted to double. - {{Value(-4LL)}, Value(-4.0)}, - // One double value. - {{Value(22.6)}, Value(22.6)}, - - // Averaging two ints. - {{Value(10), Value(11)}, Value(10.5)}, - // Averaging two longs. - {{Value(10LL), Value(11LL)}, Value(10.5)}, - // Averaging two doubles. - {{Value(10.0), Value(11.0)}, Value(10.5)}, - - // The average of an int and a double is a double. - {{Value(10), Value(11.0)}, Value(10.5)}, - // The average of a long and a double is a double. - {{Value(5LL), Value(1.0)}, Value(3.0)}, - // The average of an int and a long is a double. - {{Value(5), Value(3LL)}, Value(4.0)}, - // Averaging an int, long, and double. - {{Value(1), Value(2LL), Value(6.0)}, Value(3.0)}, - - // Unlike $sum, two ints do not overflow in the 'total' portion of the average. - {{Value(numeric_limits<int>::max()), Value(numeric_limits<int>::max())}, - Value(static_cast<double>(numeric_limits<int>::max()))}, - // Two longs do overflow in the 'total' portion of the average. - {{Value(numeric_limits<long long>::max()), Value(numeric_limits<long long>::max())}, - Value(static_cast<double>(numeric_limits<long long>::max()))}}); + { + // No documents evaluated. + {{}, Value(BSONNULL)}, + + // One int value is converted to double. + {{Value(3)}, Value(3.0)}, + // One long value is converted to double. + {{Value(-4LL)}, Value(-4.0)}, + // One double value. + {{Value(22.6)}, Value(22.6)}, + + // Averaging two ints. + {{Value(10), Value(11)}, Value(10.5)}, + // Averaging two longs. + {{Value(10LL), Value(11LL)}, Value(10.5)}, + // Averaging two doubles. + {{Value(10.0), Value(11.0)}, Value(10.5)}, + + // The average of an int and a double is a double. + {{Value(10), Value(11.0)}, Value(10.5)}, + // The average of a long and a double is a double. + {{Value(5LL), Value(1.0)}, Value(3.0)}, + // The average of an int and a long is a double. + {{Value(5), Value(3LL)}, Value(4.0)}, + // Averaging an int, long, and double. + {{Value(1), Value(2LL), Value(6.0)}, Value(3.0)}, + + // Unlike $sum, two ints do not overflow in the 'total' portion of the average. + {{Value(numeric_limits<int>::max()), Value(numeric_limits<int>::max())}, + Value(static_cast<double>(numeric_limits<int>::max()))}, + // Two longs do overflow in the 'total' portion of the average. + {{Value(numeric_limits<long long>::max()), Value(numeric_limits<long long>::max())}, + Value(static_cast<double>(numeric_limits<long long>::max()))}, + + // Averaging two decimals. + {{Value(Decimal128("-1234567890.1234567889")), + Value(Decimal128("-1234567890.1234567891"))}, + Value(Decimal128("-1234567890.1234567890"))}, + + // Averaging two longs and a decimal results in an accurate decimal result. + {{Value(1234567890123456788LL), + Value(1234567890123456789LL), + Value(Decimal128("1234567890123456790.037037036703702"))}, + Value(Decimal128("1234567890123456789.012345678901234"))}, + + // Averaging a double and a decimal + {{Value(1.0E22), Value(Decimal128("9999999999999999999999.9999999999"))}, + Value(Decimal128("9999999999999999999999.99999999995"))}, + }); } TEST(Accumulators, First) { @@ -249,12 +266,12 @@ TEST(Accumulators, Sum) { // An int and a double do not trigger an int overflow. {{Value(numeric_limits<int>::max()), Value(1.0)}, Value(static_cast<long long>(numeric_limits<int>::max()) + 1.0)}, - // An int and a long overflow. + // An int and a long overflow into a double. {{Value(1), Value(numeric_limits<long long>::max())}, - Value(numeric_limits<long long>::min())}, - // Two longs overflow. + Value(-static_cast<double>(numeric_limits<long long>::min()))}, + // Two longs overflow into a double. {{Value(numeric_limits<long long>::max()), Value(numeric_limits<long long>::max())}, - Value(-2LL)}, + Value(static_cast<double>(numeric_limits<long long>::max()) * 2)}, // A long and a double do not trigger a long overflow. {{Value(numeric_limits<long long>::max()), Value(1.0)}, Value(numeric_limits<long long>::max() + 1.0)}, diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index 9c68924d476..1c976df8f42 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -41,8 +41,10 @@ #include "mongo/db/pipeline/expression_context.h" #include "mongo/db/pipeline/value.h" #include "mongo/platform/bits.h" +#include "mongo/platform/decimal128.h" #include "mongo/util/mongoutils/str.h" #include "mongo/util/string_map.h" +#include "mongo/util/summation.h" namespace mongo { using Parser = Expression::Parser; @@ -414,6 +416,8 @@ Value ExpressionAbs::evaluateNumericArg(const Value& numericArg) const { BSONType type = numericArg.getType(); if (type == NumberDouble) { return Value(std::abs(numericArg.getDouble())); + } else if (type == NumberDecimal) { + return Value(numericArg.getDecimal().toAbs()); } else { long long num = numericArg.getLong(); uassert(28680, @@ -432,14 +436,12 @@ const char* ExpressionAbs::getOpName() const { /* ------------------------- ExpressionAdd ----------------------------- */ Value ExpressionAdd::evaluateInternal(Variables* vars) const { - /* - We'll try to return the narrowest possible result value. To do that - without creating intermediate Values, do the arithmetic for double - and integral types in parallel, tracking the current narrowest - type. - */ - double doubleTotal = 0; - long long longTotal = 0; + // We'll try to return the narrowest possible result value while avoiding overflow, loss + // of precision due to intermediate rounding or implicit use of decimal types. To do that, + // compute a compensated sum for non-decimal values and a separate decimal sum for decimal + // values, and track the current narrowest type. + DoubleDoubleSummation nonDecimalTotal; + Decimal128 decimalTotal; BSONType totalType = NumberInt; bool haveDate = false; @@ -447,40 +449,64 @@ Value ExpressionAdd::evaluateInternal(Variables* vars) const { for (size_t i = 0; i < n; ++i) { Value val = vpOperand[i]->evaluateInternal(vars); - if (val.numeric()) { - totalType = Value::getWidestNumeric(totalType, val.getType()); - - doubleTotal += val.coerceToDouble(); - longTotal += val.coerceToLong(); - } else if (val.getType() == Date) { - uassert(16612, "only one date allowed in an $add expression", !haveDate); - haveDate = true; - - // We don't manipulate totalType here. - - longTotal += val.getDate(); - doubleTotal += val.getDate(); - } else if (val.nullish()) { - return Value(BSONNULL); - } else { - uasserted(16554, - str::stream() << "$add only supports numeric or date types, not " - << typeName(val.getType())); + switch (val.getType()) { + case NumberDecimal: + decimalTotal = decimalTotal.add(val.getDecimal()); + totalType = NumberDecimal; + break; + case NumberDouble: + nonDecimalTotal.addDouble(val.getDouble()); + if (totalType != NumberDecimal) + totalType = NumberDouble; + break; + case NumberLong: + nonDecimalTotal.addLong(val.getLong()); + if (totalType == NumberInt) + totalType = NumberLong; + break; + case NumberInt: + nonDecimalTotal.addDouble(val.getInt()); + break; + case Date: + uassert(16612, "only one date allowed in an $add expression", !haveDate); + haveDate = true; + nonDecimalTotal.addLong(val.getDate()); + break; + default: + uassert(16554, + str::stream() << "$add only supports numeric or date types, not " + << typeName(val.getType()), + val.nullish()); + return Value(BSONNULL); } } if (haveDate) { - if (totalType == NumberDouble) - longTotal = static_cast<long long>(doubleTotal); + int64_t longTotal; + if (totalType == NumberDecimal) { + longTotal = decimalTotal.add(nonDecimalTotal.getDecimal()).toLong(); + } else { + uassert(ErrorCodes::Overflow, "date overflow in $add", nonDecimalTotal.fitsLong()); + longTotal = nonDecimalTotal.getLong(); + } return Value(Date_t::fromMillisSinceEpoch(longTotal)); - } else if (totalType == NumberLong) { - return Value(longTotal); - } else if (totalType == NumberDouble) { - return Value(doubleTotal); - } else if (totalType == NumberInt) { - return Value::createIntOrLong(longTotal); - } else { - massert(16417, "$add resulted in a non-numeric type", false); + } + switch (totalType) { + case NumberDecimal: + return Value(decimalTotal.add(nonDecimalTotal.getDecimal())); + case NumberLong: + dassert(nonDecimalTotal.isInteger()); + if (nonDecimalTotal.fitsLong()) + return Value(nonDecimalTotal.getLong()); + // Fallthrough. + case NumberInt: + if (nonDecimalTotal.fitsLong()) + return Value::createIntOrLong(nonDecimalTotal.getLong()); + // Fallthrough. + case NumberDouble: + return Value(nonDecimalTotal.getDouble()); + default: + massert(16417, "$add resulted in a non-numeric type", false); } } @@ -677,8 +703,16 @@ const char* ExpressionArrayElemAt::getOpName() const { Value ExpressionCeil::evaluateNumericArg(const Value& numericArg) const { // There's no point in taking the ceiling of integers or longs, it will have no effect. - return numericArg.getType() == NumberDouble ? Value(std::ceil(numericArg.getDouble())) - : numericArg; + switch (numericArg.getType()) { + case NumberDouble: + return Value(std::ceil(numericArg.getDouble())); + case NumberDecimal: + // Round toward the nearest decimal with a zero exponent in the positive direction. + return Value(numericArg.getDecimal().quantize(Decimal128::kNormalizedZero, + Decimal128::kRoundTowardPositive)); + default: + return numericArg; + } } REGISTER_EXPRESSION(ceil, ExpressionCeil::parse); @@ -1203,10 +1237,20 @@ Value ExpressionDivide::evaluateInternal(Variables* vars) const { Value lhs = vpOperand[0]->evaluateInternal(vars); Value rhs = vpOperand[1]->evaluateInternal(vars); + auto assertNonZero = [](bool nonZero) { uassert(16608, "can't $divide by zero", nonZero); }; + if (lhs.numeric() && rhs.numeric()) { + // If, and only if, either side is decimal, return decimal. + if (lhs.getType() == NumberDecimal || rhs.getType() == NumberDecimal) { + Decimal128 numer = lhs.coerceToDecimal(); + Decimal128 denom = rhs.coerceToDecimal(); + assertNonZero(!denom.isZero()); + return Value(numer.divide(denom)); + } + double numer = lhs.coerceToDouble(); double denom = rhs.coerceToDouble(); - uassert(16608, "can't $divide by zero", denom != 0); + assertNonZero(denom != 0.0); return Value(numer / denom); } else if (lhs.nullish() || rhs.nullish()) { @@ -1228,7 +1272,10 @@ const char* ExpressionDivide::getOpName() const { /* ----------------------- ExpressionExp ---------------------------- */ Value ExpressionExp::evaluateNumericArg(const Value& numericArg) const { - // exp() always returns a double since e is a double. + // $exp always returns either a double or a decimal number, as e is irrational. + if (numericArg.getType() == NumberDecimal) + return Value(numericArg.coerceToDecimal().exponential()); + return Value(exp(numericArg.coerceToDouble())); } @@ -1740,8 +1787,16 @@ void ExpressionFilter::addDependencies(DepsTracker* deps, vector<string>* path) Value ExpressionFloor::evaluateNumericArg(const Value& numericArg) const { // There's no point in taking the floor of integers or longs, it will have no effect. - return numericArg.getType() == NumberDouble ? Value(std::floor(numericArg.getDouble())) - : numericArg; + switch (numericArg.getType()) { + case NumberDouble: + return Value(std::floor(numericArg.getDouble())); + case NumberDecimal: + // Round toward the nearest decimal with a zero exponent in the negative direction. + return Value(numericArg.getDecimal().quantize(Decimal128::kNormalizedZero, + Decimal128::kRoundTowardNegative)); + default: + return numericArg; + } } REGISTER_EXPRESSION(floor, ExpressionFloor::parse); @@ -2028,10 +2083,20 @@ Value ExpressionMod::evaluateInternal(Variables* vars) const { BSONType rightType = rhs.getType(); if (lhs.numeric() && rhs.numeric()) { + auto assertNonZero = [](bool isZero) { uassert(16610, "can't $mod by zero", !isZero); }; + + // If either side is decimal, perform the operation in decimal. + if (leftType == NumberDecimal || rightType == NumberDecimal) { + Decimal128 left = lhs.coerceToDecimal(); + Decimal128 right = rhs.coerceToDecimal(); + assertNonZero(right.isZero()); + + return Value(left.modulo(right)); + } + // ensure we aren't modding by 0 double right = rhs.coerceToDouble(); - - uassert(16610, "can't $mod by 0", right != 0); + assertNonZero(right == 0); if (leftType == NumberDouble || (rightType == NumberDouble && !rhs.integral())) { // Need to do fmod. Integer-valued double case is handled below. @@ -2088,6 +2153,8 @@ Value ExpressionMultiply::evaluateInternal(Variables* vars) const { */ double doubleProduct = 1; long long longProduct = 1; + Decimal128 decimalProduct; // This will be initialized on encountering the first decimal. + BSONType productType = NumberInt; const size_t n = vpOperand.size(); @@ -2095,10 +2162,23 @@ Value ExpressionMultiply::evaluateInternal(Variables* vars) const { Value val = vpOperand[i]->evaluateInternal(vars); if (val.numeric()) { + BSONType oldProductType = productType; productType = Value::getWidestNumeric(productType, val.getType()); - - doubleProduct *= val.coerceToDouble(); - longProduct *= val.coerceToLong(); + if (productType == NumberDecimal) { + // On finding the first decimal, convert the partial product to decimal. + if (oldProductType != NumberDecimal) { + decimalProduct = oldProductType == NumberDouble + ? Decimal128(doubleProduct, Decimal128::kRoundTo15Digits) + : Decimal128(static_cast<int64_t>(longProduct)); + } + decimalProduct = decimalProduct.multiply(val.coerceToDecimal()); + } else { + doubleProduct *= val.coerceToDouble(); + if (mongoSignedMultiplyOverflow64(longProduct, val.coerceToLong(), &longProduct)) { + // The 'longProduct' would have overflowed, so we're abandoning it. + productType = NumberDouble; + } + } } else if (val.nullish()) { return Value(BSONNULL); } else { @@ -2114,6 +2194,8 @@ Value ExpressionMultiply::evaluateInternal(Variables* vars) const { return Value(longProduct); else if (productType == NumberInt) return Value::createIntOrLong(longProduct); + else if (productType == NumberDecimal) + return Value(decimalProduct); else massert(16418, "$multiply resulted in a non-numeric type", false); } @@ -2390,6 +2472,12 @@ const char* ExpressionIndexOfCP::getOpName() const { /* ----------------------- ExpressionLn ---------------------------- */ Value ExpressionLn::evaluateNumericArg(const Value& numericArg) const { + if (numericArg.getType() == NumberDecimal) { + Decimal128 argDecimal = numericArg.getDecimal(); + if (argDecimal.isGreater(Decimal128::kNormalizedZero)) + return Value(argDecimal.logarithm()); + // Fall through for error case. + } double argDouble = numericArg.coerceToDouble(); uassert(28766, str::stream() << "$ln's argument must be a positive number, but is " << argDouble, @@ -2417,6 +2505,18 @@ Value ExpressionLog::evaluateInternal(Variables* vars) const { str::stream() << "$log's base must be numeric, not " << typeName(baseVal.getType()), baseVal.numeric()); + if (argVal.getType() == NumberDecimal || baseVal.getType() == NumberDecimal) { + Decimal128 argDecimal = argVal.coerceToDecimal(); + Decimal128 baseDecimal = baseVal.coerceToDecimal(); + + if (argDecimal.isGreater(Decimal128::kNormalizedZero) && + baseDecimal.isNotEqual(Decimal128(1)) && + baseDecimal.isGreater(Decimal128::kNormalizedZero)) { + return Value(argDecimal.logarithm(baseDecimal)); + } + // Fall through for error cases. + } + double argDouble = argVal.coerceToDouble(); double baseDouble = baseVal.coerceToDouble(); uassert(28758, @@ -2437,6 +2537,13 @@ const char* ExpressionLog::getOpName() const { /* ----------------------- ExpressionLog10 ---------------------------- */ Value ExpressionLog10::evaluateNumericArg(const Value& numericArg) const { + if (numericArg.getType() == NumberDecimal) { + Decimal128 argDecimal = numericArg.getDecimal(); + if (argDecimal.isGreater(Decimal128::kNormalizedZero)) + return Value(argDecimal.logarithm(Decimal128(10))); + // Fall through for error case. + } + double argDouble = numericArg.coerceToDouble(); uassert(28761, str::stream() << "$log10's argument must be a positive number, but is " << argDouble, @@ -2672,15 +2779,24 @@ Value ExpressionPow::evaluateInternal(Variables* vars) const { str::stream() << "$pow's exponent must be numeric, not " << typeName(expType), expVal.numeric()); + auto checkNonZeroAndNeg = [](bool isZeroAndNeg) { + uassert(28764, "$pow cannot take a base of 0 and a negative exponent", !isZeroAndNeg); + }; + + // If either argument is decimal, return a decimal. + if (baseType == NumberDecimal || expType == NumberDecimal) { + Decimal128 baseDecimal = baseVal.coerceToDecimal(); + Decimal128 expDecimal = expVal.coerceToDecimal(); + checkNonZeroAndNeg(baseDecimal.isZero() && expDecimal.isNegative()); + return Value(baseDecimal.power(expDecimal)); + } + // pow() will cast args to doubles. double baseDouble = baseVal.coerceToDouble(); double expDouble = expVal.coerceToDouble(); + checkNonZeroAndNeg(baseDouble == 0 && expDouble < 0); - uassert(28764, - "$pow cannot take a base of 0 and a negative exponent", - !(baseDouble == 0 && expDouble < 0)); - - // If either number is a double, return a double. + // If either argument is a double, return a double. if (baseType == NumberDouble || expType == NumberDouble) { return Value(std::pow(baseDouble, expDouble)); } @@ -3414,10 +3530,17 @@ const char* ExpressionSplit::getOpName() const { /* ----------------------- ExpressionSqrt ---------------------------- */ Value ExpressionSqrt::evaluateNumericArg(const Value& numericArg) const { + auto checkArg = [](bool nonNegative) { + uassert(28714, "$sqrt's argument must be greater than or equal to 0", nonNegative); + }; + + if (numericArg.getType() == NumberDecimal) { + Decimal128 argDec = numericArg.getDecimal(); + checkArg(!argDec.isLess(Decimal128::kNormalizedZero)); // NaN returns Nan without error + return Value(argDec.squareRoot()); + } double argDouble = numericArg.coerceToDouble(); - uassert(28714, - "$sqrt's argument must be greater than or equal to 0", - argDouble >= 0 || std::isnan(argDouble)); + checkArg(!(argDouble < 0)); // NaN returns Nan without error return Value(sqrt(argDouble)); } @@ -3638,7 +3761,11 @@ Value ExpressionSubtract::evaluateInternal(Variables* vars) const { BSONType diffType = Value::getWidestNumeric(rhs.getType(), lhs.getType()); - if (diffType == NumberDouble) { + if (diffType == NumberDecimal) { + Decimal128 right = rhs.coerceToDecimal(); + Decimal128 left = lhs.coerceToDecimal(); + return Value(left.subtract(right)); + } else if (diffType == NumberDouble) { double right = rhs.coerceToDouble(); double left = lhs.coerceToDouble(); return Value(left - right); @@ -3838,8 +3965,15 @@ const char* ExpressionToUpper::getOpName() const { Value ExpressionTrunc::evaluateNumericArg(const Value& numericArg) const { // There's no point in truncating integers or longs, it will have no effect. - return numericArg.getType() == NumberDouble ? Value(std::trunc(numericArg.getDouble())) - : numericArg; + switch (numericArg.getType()) { + case NumberDecimal: + return Value(numericArg.getDecimal().quantize(Decimal128::kNormalizedZero, + Decimal128::kRoundTowardZero)); + case NumberDouble: + return Value(std::trunc(numericArg.getDouble())); + default: + return numericArg; + } } REGISTER_EXPRESSION(trunc, ExpressionTrunc::parse); diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp index 2487bfc18f1..a72d977ee27 100644 --- a/src/mongo/db/pipeline/expression_test.cpp +++ b/src/mongo/db/pipeline/expression_test.cpp @@ -693,7 +693,7 @@ TEST_F(ExpressionCeilTest, LongArg) { Value(numeric_limits<long long>::max())); } -TEST_F(ExpressionCeilTest, FloatArg) { +TEST_F(ExpressionCeilTest, DoubleArg) { assertEvaluates(Value(2.0), Value(2.0)); assertEvaluates(Value(-2.0), Value(-2.0)); assertEvaluates(Value(0.9), Value(1.0)); @@ -709,6 +709,20 @@ TEST_F(ExpressionCeilTest, FloatArg) { assertEvaluates(Value(smallerThanLong), Value(smallerThanLong)); } +TEST_F(ExpressionCeilTest, DecimalArg) { + assertEvaluates(Value(Decimal128("2")), Value(Decimal128("2.0"))); + assertEvaluates(Value(Decimal128("-2")), Value(Decimal128("-2.0"))); + assertEvaluates(Value(Decimal128("0.9")), Value(Decimal128("1.0"))); + assertEvaluates(Value(Decimal128("0.1")), Value(Decimal128("1.0"))); + assertEvaluates(Value(Decimal128("-1.2")), Value(Decimal128("-1.0"))); + assertEvaluates(Value(Decimal128("-1.7")), Value(Decimal128("-1.0"))); + assertEvaluates(Value(Decimal128("1234567889.000000000000000000000001")), + Value(Decimal128("1234567890"))); + assertEvaluates(Value(Decimal128("-99999999999999999999999999999.99")), + Value(Decimal128("-99999999999999999999999999999.00"))); + assertEvaluates(Value(Decimal128("3.4E-6000")), Value(Decimal128("1"))); +} + TEST_F(ExpressionCeilTest, NullArg) { assertEvaluates(Value(BSONNULL), Value(BSONNULL)); } @@ -737,7 +751,7 @@ TEST_F(ExpressionFloorTest, LongArg) { Value(numeric_limits<long long>::max())); } -TEST_F(ExpressionFloorTest, FloatArg) { +TEST_F(ExpressionFloorTest, DoubleArg) { assertEvaluates(Value(2.0), Value(2.0)); assertEvaluates(Value(-2.0), Value(-2.0)); assertEvaluates(Value(0.9), Value(0.0)); @@ -753,6 +767,20 @@ TEST_F(ExpressionFloorTest, FloatArg) { assertEvaluates(Value(smallerThanLong), Value(smallerThanLong)); } +TEST_F(ExpressionFloorTest, DecimalArg) { + assertEvaluates(Value(Decimal128("2")), Value(Decimal128("2.0"))); + assertEvaluates(Value(Decimal128("-2")), Value(Decimal128("-2.0"))); + assertEvaluates(Value(Decimal128("0.9")), Value(Decimal128("0.0"))); + assertEvaluates(Value(Decimal128("0.1")), Value(Decimal128("0.0"))); + assertEvaluates(Value(Decimal128("-1.2")), Value(Decimal128("-2.0"))); + assertEvaluates(Value(Decimal128("-1.7")), Value(Decimal128("-2.0"))); + assertEvaluates(Value(Decimal128("1234567890.000000000000000000000001")), + Value(Decimal128("1234567890"))); + assertEvaluates(Value(Decimal128("-99999999999999999999999999999.99")), + Value(Decimal128("-100000000000000000000000000000"))); + assertEvaluates(Value(Decimal128("3.4E-6000")), Value(Decimal128("0"))); +} + TEST_F(ExpressionFloorTest, NullArg) { assertEvaluates(Value(BSONNULL), Value(BSONNULL)); } @@ -838,7 +866,7 @@ TEST_F(ExpressionTruncTest, LongArg) { Value(numeric_limits<long long>::max())); } -TEST_F(ExpressionTruncTest, FloatArg) { +TEST_F(ExpressionTruncTest, DoubleArg) { assertEvaluates(Value(2.0), Value(2.0)); assertEvaluates(Value(-2.0), Value(-2.0)); assertEvaluates(Value(0.9), Value(0.0)); @@ -854,6 +882,20 @@ TEST_F(ExpressionTruncTest, FloatArg) { assertEvaluates(Value(smallerThanLong), Value(smallerThanLong)); } +TEST_F(ExpressionTruncTest, DecimalArg) { + assertEvaluates(Value(Decimal128("2")), Value(Decimal128("2.0"))); + assertEvaluates(Value(Decimal128("-2")), Value(Decimal128("-2.0"))); + assertEvaluates(Value(Decimal128("0.9")), Value(Decimal128("0.0"))); + assertEvaluates(Value(Decimal128("0.1")), Value(Decimal128("0.0"))); + assertEvaluates(Value(Decimal128("-1.2")), Value(Decimal128("-1.0"))); + assertEvaluates(Value(Decimal128("-1.7")), Value(Decimal128("-1.0"))); + assertEvaluates(Value(Decimal128("123456789.9999999999999999999999999")), + Value(Decimal128("123456789"))); + assertEvaluates(Value(Decimal128("-99999999999999999999999999999.99")), + Value(Decimal128("-99999999999999999999999999999.00"))); + assertEvaluates(Value(Decimal128("3.4E-6000")), Value(Decimal128("0"))); +} + TEST_F(ExpressionTruncTest, NullArg) { assertEvaluates(Value(BSONNULL), Value(BSONNULL)); } @@ -1038,8 +1080,8 @@ class IntLong : public TwoOperandBase { } }; -/** Adding an int and a long overflows. */ -class IntLongOverflow : public TwoOperandBase { +/** Adding an int and a long produces a double. */ +class IntLongOverflowToDouble : public TwoOperandBase { BSONObj operand1() { return BSON("" << numeric_limits<int>::max()); } @@ -1047,11 +1089,10 @@ class IntLongOverflow : public TwoOperandBase { return BSON("" << numeric_limits<long long>::max()); } BSONObj expectedResult() { - // Aggregation currently treats signed integers as overflowing like unsigned integers do. + // When the result cannot be represented in a NumberLong, a NumberDouble is returned. const auto im = numeric_limits<int>::max(); const auto llm = numeric_limits<long long>::max(); - const auto result = static_cast<long long>(static_cast<unsigned int>(im) + - static_cast<unsigned long long>(llm)); + double result = static_cast<double>(im) + static_cast<double>(llm); return BSON("" << result); } }; @@ -4896,7 +4937,7 @@ public: add<Add::IntInt>(); add<Add::IntIntNoOverflow>(); add<Add::IntLong>(); - add<Add::IntLongOverflow>(); + add<Add::IntLongOverflowToDouble>(); add<Add::IntDouble>(); add<Add::IntDate>(); add<Add::LongDouble>(); diff --git a/src/mongo/db/pipeline/value.cpp b/src/mongo/db/pipeline/value.cpp index 8a37c51067c..37b3b88afae 100644 --- a/src/mongo/db/pipeline/value.cpp +++ b/src/mongo/db/pipeline/value.cpp @@ -971,8 +971,6 @@ BSONType Value::getWidestNumeric(BSONType lType, BSONType rType) { return Undefined; } -// TODO: Add Decimal128 support to Value::integral() -// SERVER-19735 bool Value::integral() const { switch (getType()) { case NumberInt: @@ -984,6 +982,13 @@ bool Value::integral() const { return (_storage.doubleValue <= numeric_limits<int>::max() && _storage.doubleValue >= numeric_limits<int>::min() && _storage.doubleValue == static_cast<int>(_storage.doubleValue)); + case NumberDecimal: { + // If we are able to convert the decimal to an int32_t without an rounding errors, + // then it is integral. + uint32_t signalingFlags = Decimal128::kNoFlag; + (void)_storage.getDecimal().toInt(&signalingFlags); + return signalingFlags == Decimal128::kNoFlag; + } default: return false; } diff --git a/src/mongo/db/pipeline/value.h b/src/mongo/db/pipeline/value.h index c6d4b90c0cd..5ac5c61180b 100644 --- a/src/mongo/db/pipeline/value.h +++ b/src/mongo/db/pipeline/value.h @@ -119,11 +119,9 @@ public: } /// true if type represents a number - // TODO: Add _storage.type == NumberDecimal - // SERVER-19735 bool numeric() const { return _storage.type == NumberDouble || _storage.type == NumberLong || - _storage.type == NumberInt; + _storage.type == NumberInt || _storage.type == NumberDecimal; } /** |