From 00ed8f3b275971093ddd2ee7d3ab558904e28af0 Mon Sep 17 00:00:00 2001 From: Geert Bosch Date: Wed, 1 Jun 2016 22:52:44 -0400 Subject: SERVER-19735: Add support for decimal type in aggregation --- etc/ubsan.blacklist | 7 - jstests/aggregation/bugs/server18427.js | 50 +++++- src/mongo/db/pipeline/SConscript | 1 + src/mongo/db/pipeline/accumulator.h | 17 +- src/mongo/db/pipeline/accumulator_avg.cpp | 80 ++++++--- src/mongo/db/pipeline/accumulator_sum.cpp | 105 ++++++++---- src/mongo/db/pipeline/accumulator_test.cpp | 89 +++++----- src/mongo/db/pipeline/expression.cpp | 252 ++++++++++++++++++++++------- src/mongo/db/pipeline/expression_test.cpp | 59 +++++-- src/mongo/db/pipeline/value.cpp | 9 +- src/mongo/db/pipeline/value.h | 4 +- 11 files changed, 504 insertions(+), 169 deletions(-) diff --git a/etc/ubsan.blacklist b/etc/ubsan.blacklist index 4b5278b0a36..efd9562dacf 100644 --- a/etc/ubsan.blacklist +++ b/etc/ubsan.blacklist @@ -19,13 +19,6 @@ src:src/third_party/zlib-*/* # which trips up UBSAN. src:src/mongo/db/storage/mmap_v1/* -# See SERVER-23713. The pipeline arith expressions and accumulators need -# to be re-written to avoid undefined behavior. 
-fun:_ZNK5mongo13ExpressionAdd16evaluateInternalEPNS_9VariablesE -fun:_ZNK5mongo18ExpressionMultiply16evaluateInternalEPNS_9VariablesE -fun:_ZN5mongo14AccumulatorSum15processInternalERKNS_5ValueEb -fun:_ZNK5mongo5Value12coerceToLongEv - # Blacklisting these functions due to a bug in libstdc++: # http://stackoverflow.com/questions/30122500/is-this-code-really-undefined-as-clang-seems-to-indicate fun:_ZStaNRSt13_Ios_FmtflagsS_ diff --git a/jstests/aggregation/bugs/server18427.js b/jstests/aggregation/bugs/server18427.js index 35fcef8a4ac..25c34a83bb1 100644 --- a/jstests/aggregation/bugs/server18427.js +++ b/jstests/aggregation/bugs/server18427.js @@ -9,6 +9,9 @@ load('jstests/aggregation/extras/utils.js'); coll.drop(); assert.writeOK(coll.insert({_id: 0})); + var decimalE = NumberDecimal("2.718281828459045235360287471352662"); + var decimal1overE = NumberDecimal("0.3678794411714423215955237701614609"); + // Helper for testing that op returns expResult. function testOp(op, expResult) { var pipeline = [{$project: {_id: 0, result: op}}]; @@ -18,9 +21,15 @@ load('jstests/aggregation/extras/utils.js'); // $log, $log10, $ln. // Valid input: numeric/null/NaN, base positive and not equal to 1, arg positive. + // - NumberDouble testOp({$log: [10, 10]}, 1); testOp({$log10: [10]}, 1); testOp({$ln: [Math.E]}, 1); + // - NumberDecimal + testOp({$log: [NumberDecimal("10"), NumberDecimal("10")]}, NumberDecimal("1")); + testOp({$log10: [NumberDecimal("10")]}, NumberDecimal("1")); + // The below answer is actually correct: the input is an approximation of E + testOp({$ln: [decimalE]}, NumberDecimal("0.9999999999999999999999999999999998")); // All types converted to doubles. testOp({$log: [NumberLong("10"), NumberLong("10")]}, 1); testOp({$log10: [NumberLong("10")]}, 1); @@ -30,11 +39,15 @@ load('jstests/aggregation/extras/utils.js'); // Null inputs result in null. 
testOp({$log: [null, 10]}, null); testOp({$log: [10, null]}, null); + testOp({$log: [null, NumberDecimal(10)]}, null); + testOp({$log: [NumberDecimal(10), null]}, null); testOp({$log10: [null]}, null); testOp({$ln: [null]}, null); // NaN inputs result in NaN. testOp({$log: [NaN, 10]}, NaN); testOp({$log: [10, NaN]}, NaN); + testOp({$log: [NaN, NumberDecimal(10)]}, NaN); + testOp({$log: [NumberDecimal(10), NaN]}, NaN); testOp({$log10: [NaN]}, NaN); testOp({$ln: [NaN]}, NaN); @@ -50,13 +63,24 @@ load('jstests/aggregation/extras/utils.js'); assertErrorCode(coll, [{$project: {log: {$log: [5, 0]}}}], 28759); assertErrorCode(coll, [{$project: {log10: {$log10: [0]}}}], 28761); assertErrorCode(coll, [{$project: {ln: {$ln: [0]}}}], 28766); + assertErrorCode(coll, [{$project: {log: {$log: [NumberDecimal(0), NumberDecimal(5)]}}}], 28758); + assertErrorCode(coll, [{$project: {log: {$log: [NumberDecimal(5), NumberDecimal(0)]}}}], 28759); + assertErrorCode(coll, [{$project: {log10: {$log10: [NumberDecimal(0)]}}}], 28761); + assertErrorCode(coll, [{$project: {ln: {$ln: [NumberDecimal(0)]}}}], 28766); // Args/bases cannot be negative. assertErrorCode(coll, [{$project: {log: {$log: [-1, 5]}}}], 28758); assertErrorCode(coll, [{$project: {log: {$log: [5, -1]}}}], 28759); assertErrorCode(coll, [{$project: {log10: {$log10: [-1]}}}], 28761); assertErrorCode(coll, [{$project: {ln: {$ln: [-1]}}}], 28766); + assertErrorCode( + coll, [{$project: {log: {$log: [NumberDecimal(-1), NumberDecimal(5)]}}}], 28758); + assertErrorCode( + coll, [{$project: {log: {$log: [NumberDecimal(5), NumberDecimal(-1)]}}}], 28759); + assertErrorCode(coll, [{$project: {log10: {$log10: [NumberDecimal(-1)]}}}], 28761); + assertErrorCode(coll, [{$project: {ln: {$ln: [NumberDecimal(-1)]}}}], 28766); // Base can't equal 1. assertErrorCode(coll, [{$project: {log: {$log: [5, 1]}}}], 28759); + assertErrorCode(coll, [{$project: {log: {$log: [NumberDecimal(5), NumberDecimal(1)]}}}], 28759); // $pow, $exp. 
@@ -68,6 +92,18 @@ load('jstests/aggregation/extras/utils.js'); testOp({$pow: [-2, 2]}, 4); testOp({$pow: [NumberInt("2"), 2]}, 4); testOp({$pow: [-2, NumberInt("2")]}, 4); + // $pow -- if either input is a NumberDecimal, return a NumberDecimal + testOp({$pow: [NumberDecimal("10.0"), -2]}, + NumberDecimal("0.01000000000000000000000000000000000")); + testOp({$pow: [0.5, NumberDecimal("-1")]}, + NumberDecimal("2.000000000000000000000000000000000")); + testOp({$pow: [-2, NumberDecimal("2")]}, NumberDecimal("4.000000000000000000000000000000000")); + testOp({$pow: [NumberInt("2"), NumberDecimal("2")]}, + NumberDecimal("4.000000000000000000000000000000000")); + testOp({$pow: [NumberDecimal("-2.0"), NumberInt("2")]}, + NumberDecimal("4.000000000000000000000000000000000")); + testOp({$pow: [NumberDecimal("10.0"), 2]}, + NumberDecimal("100.0000000000000000000000000000000")); // If exponent is negative and base not -1, 0, or 1, return a double. testOp({$pow: [NumberLong("2"), NumberLong("-1")]}, 1 / 2); @@ -78,6 +114,9 @@ load('jstests/aggregation/extras/utils.js'); // If result would overflow a long, return a double. testOp({$pow: [NumberInt("2"), NumberLong("63")]}, 9223372036854776000); + // Exact decimal result + testOp({$pow: [NumberInt("5"), NumberDecimal("-112")]}, + NumberDecimal("5192296858534827628530496329220096E-112")); // Result would be incorrect if double were returned. testOp({$pow: [NumberInt("3"), NumberInt("35")]}, NumberLong("50031545098999707")); @@ -91,16 +130,23 @@ load('jstests/aggregation/extras/utils.js'); // Else return an int if it fits. testOp({$pow: [NumberInt("4"), NumberInt("2")]}, 16); - // $exp always returns doubles, since e is a double. + // $exp always returns doubles for non-zero non-decimal inputs, since e is a double. 
testOp({$exp: [NumberInt("-1")]}, 1 / Math.E); testOp({$exp: [NumberLong("1")]}, Math.E); + // $exp returns decimal results for decimal inputs + testOp({$exp: [NumberDecimal("-1")]}, decimal1overE); + testOp({$exp: [NumberDecimal("1")]}, decimalE); // Null input results in null. testOp({$pow: [null, 2]}, null); testOp({$pow: [1 / 2, null]}, null); + testOp({$pow: [null, NumberDecimal(2)]}, null); + testOp({$pow: [NumberDecimal("0.5"), null]}, null); testOp({$exp: [null]}, null); // NaN input results in NaN. testOp({$pow: [NaN, 2]}, NaN); testOp({$pow: [1 / 2, NaN]}, NaN); + testOp({$pow: [NaN, NumberDecimal(2)]}, NumberDecimal("NaN")); + testOp({$pow: [NumberDecimal("0.5"), NaN]}, NumberDecimal("NaN")); testOp({$exp: [NaN]}, NaN); // Invalid inputs - non-numeric/non-null types, or 0 to a negative exponent. @@ -108,4 +154,6 @@ load('jstests/aggregation/extras/utils.js'); assertErrorCode(coll, [{$project: {pow: {$pow: ["string", 5]}}}], 28762); assertErrorCode(coll, [{$project: {pow: {$pow: [5, "string"]}}}], 28763); assertErrorCode(coll, [{$project: {exp: {$exp: ["string"]}}}], 28765); + assertErrorCode(coll, [{$project: {pow: {$pow: [NumberDecimal(0), NumberLong("-1")]}}}], 28764); + assertErrorCode(coll, [{$project: {pow: {$pow: ["string", NumberDecimal(5)]}}}], 28762); }()); diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript index eef0b7ecada..8525401fb72 100644 --- a/src/mongo/db/pipeline/SConscript +++ b/src/mongo/db/pipeline/SConscript @@ -91,6 +91,7 @@ env.Library( ], LIBDEPS=[ 'document_value', + '$BUILD_DIR/mongo/util/summation', 'expression', 'field_path', ] diff --git a/src/mongo/db/pipeline/accumulator.h b/src/mongo/db/pipeline/accumulator.h index 7dd604b82b4..eb09459f942 100644 --- a/src/mongo/db/pipeline/accumulator.h +++ b/src/mongo/db/pipeline/accumulator.h @@ -38,6 +38,7 @@ #include "mongo/bson/bsontypes.h" #include "mongo/db/pipeline/value.h" #include "mongo/stdx/functional.h" +#include "mongo/util/summation.h" 
namespace mongo { /** @@ -194,9 +195,9 @@ public: } private: - BSONType totalType; - long long longTotal; - double doubleTotal; + BSONType totalType = NumberInt; + DoubleDoubleSummation nonDecimalTotal; + Decimal128 decimalTotal; }; @@ -268,7 +269,15 @@ public: static boost::intrusive_ptr create(); private: - double _total; + /** + * The total of all values is partitioned between those that are decimals, and those that are + * not decimals, so the decimal total needs to add the non-decimal. + */ + Decimal128 _getDecimalTotal() const; + + bool _isDecimal; + DoubleDoubleSummation _nonDecimalTotal; + Decimal128 _decimalTotal; long long _count; }; diff --git a/src/mongo/db/pipeline/accumulator_avg.cpp b/src/mongo/db/pipeline/accumulator_avg.cpp index ed11d81ecc0..f9faf8d359a 100644 --- a/src/mongo/db/pipeline/accumulator_avg.cpp +++ b/src/mongo/db/pipeline/accumulator_avg.cpp @@ -33,6 +33,7 @@ #include "mongo/db/pipeline/expression.h" #include "mongo/db/pipeline/expression_context.h" #include "mongo/db/pipeline/value.h" +#include "mongo/platform/decimal128.h" namespace mongo { @@ -47,48 +48,85 @@ const char* AccumulatorAvg::getOpName() const { namespace { const char subTotalName[] = "subTotal"; +const char subTotalErrorName[] = "subTotalError"; // Used for extra precision const char countName[] = "count"; -} +} // namespace void AccumulatorAvg::processInternal(const Value& input, bool merging) { - if (!merging) { - // non numeric types have no impact on average - if (!input.numeric()) - return; - - _total += input.getDouble(); - _count += 1; - } else { - // We expect an object that contains both a subtotal and a count. - // This is what getValue(true) produced below. + if (merging) { + // We expect an object that contains both a subtotal and a count. Additionally there may + // be an error value, that allows for additional precision. + // 'input' is what getValue(true) produced below. 
verify(input.getType() == Object); - _total += input[subTotalName].getDouble(); - _count += input[countName].getLong(); + // We're recursively adding the subtotal to get the proper type treatment, but this only + // increments the count by one, so adjust the count afterwards. Similarly for 'error'. + processInternal(input[subTotalName], false); + _count += input[countName].getLong() - 1; + Value error = input[subTotalErrorName]; + if (!error.missing()) { + processInternal(error, false); + _count--; // The error correction only adjusts the total, not the number of items. + } + return; + } + + switch (input.getType()) { + case NumberDecimal: + _decimalTotal = _decimalTotal.add(input.getDecimal()); + _isDecimal = true; + break; + case NumberLong: + // Avoid summation using double as that loses precision. + _nonDecimalTotal.addLong(input.getLong()); + break; + case NumberInt: + case NumberDouble: + _nonDecimalTotal.addDouble(input.getDouble()); + break; + default: + dassert(!input.numeric()); + return; } + _count++; } intrusive_ptr AccumulatorAvg::create() { return new AccumulatorAvg(); } +Decimal128 AccumulatorAvg::_getDecimalTotal() const { + return _decimalTotal.add(_nonDecimalTotal.getDecimal()); +} + Value AccumulatorAvg::getValue(bool toBeMerged) const { - if (!toBeMerged) { - if (_count == 0) - return Value(BSONNULL); + if (toBeMerged) { + if (_isDecimal) + return Value(Document{{subTotalName, _getDecimalTotal()}, {countName, _count}}); - return Value(_total / static_cast(_count)); - } else { - return Value(DOC(subTotalName << _total << countName << _count)); + double total, error; + std::tie(total, error) = _nonDecimalTotal.getDoubleDouble(); + return Value( + Document{{subTotalName, total}, {countName, _count}, {subTotalErrorName, error}}); } + + if (_count == 0) + return Value(BSONNULL); + + if (_isDecimal) + return Value(_getDecimalTotal().divide(Decimal128(static_cast(_count)))); + + return Value(_nonDecimalTotal.getDouble() / static_cast(_count)); } 
-AccumulatorAvg::AccumulatorAvg() : _total(0), _count(0) { +AccumulatorAvg::AccumulatorAvg() : _isDecimal(false), _count(0) { // This is a fixed size Accumulator so we never need to update this _memUsageBytes = sizeof(*this); } void AccumulatorAvg::reset() { - _total = 0; + _isDecimal = false; + _nonDecimalTotal = {}; + _decimalTotal = {}; _count = 0; } } diff --git a/src/mongo/db/pipeline/accumulator_sum.cpp b/src/mongo/db/pipeline/accumulator_sum.cpp index c064fe52f04..1255c3fde8b 100644 --- a/src/mongo/db/pipeline/accumulator_sum.cpp +++ b/src/mongo/db/pipeline/accumulator_sum.cpp @@ -28,9 +28,13 @@ #include "mongo/platform/basic.h" +#include +#include + #include "mongo/db/pipeline/accumulator.h" #include "mongo/db/pipeline/expression.h" #include "mongo/db/pipeline/value.h" +#include "mongo/util/summation.h" namespace mongo { @@ -43,24 +47,38 @@ const char* AccumulatorSum::getOpName() const { return "$sum"; } +namespace { +const char subTotalName[] = "subTotal"; +const char subTotalErrorName[] = "subTotalError"; // Used for extra precision. +} // namespace + + void AccumulatorSum::processInternal(const Value& input, bool merging) { - // do nothing with non numeric types - if (!input.numeric()) + if (!input.numeric()) { + if (merging && input.getType() == Object) { + // Process merge document, see getValue() below. + nonDecimalTotal.addDouble( + input[subTotalName].getDouble()); // Sum without adjusting type. + processInternal(input[subTotalErrorName], false); // Sum adjusting for type of error. + } return; + } - // upgrade to the widest type required to hold the result + // Upgrade to the widest type required to hold the result. 
totalType = Value::getWidestNumeric(totalType, input.getType()); - - if (totalType == NumberInt || totalType == NumberLong) { - long long v = input.coerceToLong(); - longTotal += v; - doubleTotal += v; - } else if (totalType == NumberDouble) { - double v = input.coerceToDouble(); - doubleTotal += v; - } else { - // non numerics should have returned above so we should never get here - verify(false); + switch (input.getType()) { + case NumberInt: + case NumberLong: + nonDecimalTotal.addLong(input.coerceToLong()); + break; + case NumberDouble: + nonDecimalTotal.addDouble(input.getDouble()); + break; + case NumberDecimal: + decimalTotal = decimalTotal.add(input.coerceToDecimal()); + break; + default: + MONGO_UNREACHABLE; } } @@ -69,25 +87,58 @@ intrusive_ptr AccumulatorSum::create() { } Value AccumulatorSum::getValue(bool toBeMerged) const { - if (totalType == NumberLong) { - return Value(longTotal); - } else if (totalType == NumberDouble) { - return Value(doubleTotal); - } else if (totalType == NumberInt) { - return Value::createIntOrLong(longTotal); - } else { - massert(16000, "$sum resulted in a non-numeric type", false); + switch (totalType) { + case NumberInt: + if (nonDecimalTotal.fitsLong()) + return Value::createIntOrLong(nonDecimalTotal.getLong()); + // Fallthrough. + case NumberLong: + if (nonDecimalTotal.fitsLong()) + return Value(nonDecimalTotal.getLong()); + if (toBeMerged) { + // The value was too large for a NumberLong, so output a document with two values + // adding up to the desired total. Older MongoDB versions used to ignore signed + // integer overflow and cause undefined behavior, that in practice resulted in + // values that would wrap around modulo 2**64. Now an older mongos with a newer + // mongod will yield an error that $sum resulted in a non-numeric type, which is + // OK for this case. Output the error using the totalType, so in the future we can + // determine the correct totalType for the sum. 
For the error to exceed 2**63, + // more than 2**53 integers would have to be summed, which is impossible. + double total; + double error; + std::tie(total, error) = nonDecimalTotal.getDoubleDouble(); + long long llerror = static_cast(error); + return Value(DOC(subTotalName << total << subTotalErrorName << llerror)); + } + // Sum doesn't fit a NumberLong, so return a NumberDouble instead. + return Value(nonDecimalTotal.getDouble()); + + case NumberDouble: + return Value(nonDecimalTotal.getDouble()); + case NumberDecimal: { + double sum, error; + std::tie(sum, error) = nonDecimalTotal.getDoubleDouble(); + Decimal128 total; // zero + if (sum != 0) { + total = total.add(Decimal128(sum, Decimal128::kRoundTo34Digits)); + total = total.add(Decimal128(error, Decimal128::kRoundTo34Digits)); + } + total = total.add(decimalTotal); + return Value(total); + } + default: + MONGO_UNREACHABLE; } } -AccumulatorSum::AccumulatorSum() : totalType(NumberInt), longTotal(0), doubleTotal(0) { - // This is a fixed size Accumulator so we never need to update this +AccumulatorSum::AccumulatorSum() { + // This is a fixed size Accumulator so we never need to update this. _memUsageBytes = sizeof(*this); } void AccumulatorSum::reset() { totalType = NumberInt; - longTotal = 0; - doubleTotal = 0; -} + nonDecimalTotal = {}; + decimalTotal = {}; } +} // namespace mongo diff --git a/src/mongo/db/pipeline/accumulator_test.cpp b/src/mongo/db/pipeline/accumulator_test.cpp index 928eb2e868e..e354d71d8f9 100644 --- a/src/mongo/db/pipeline/accumulator_test.cpp +++ b/src/mongo/db/pipeline/accumulator_test.cpp @@ -96,38 +96,55 @@ static void assertExpectedResults( TEST(Accumulators, Avg) { assertExpectedResults( "$avg", - {// No documents evaluated. - {{}, Value(BSONNULL)}, - - // One int value is converted to double. - {{Value(3)}, Value(3.0)}, - // One long value is converted to double. - {{Value(-4LL)}, Value(-4.0)}, - // One double value. - {{Value(22.6)}, Value(22.6)}, - - // Averaging two ints. 
- {{Value(10), Value(11)}, Value(10.5)}, - // Averaging two longs. - {{Value(10LL), Value(11LL)}, Value(10.5)}, - // Averaging two doubles. - {{Value(10.0), Value(11.0)}, Value(10.5)}, - - // The average of an int and a double is a double. - {{Value(10), Value(11.0)}, Value(10.5)}, - // The average of a long and a double is a double. - {{Value(5LL), Value(1.0)}, Value(3.0)}, - // The average of an int and a long is a double. - {{Value(5), Value(3LL)}, Value(4.0)}, - // Averaging an int, long, and double. - {{Value(1), Value(2LL), Value(6.0)}, Value(3.0)}, - - // Unlike $sum, two ints do not overflow in the 'total' portion of the average. - {{Value(numeric_limits::max()), Value(numeric_limits::max())}, - Value(static_cast(numeric_limits::max()))}, - // Two longs do overflow in the 'total' portion of the average. - {{Value(numeric_limits::max()), Value(numeric_limits::max())}, - Value(static_cast(numeric_limits::max()))}}); + { + // No documents evaluated. + {{}, Value(BSONNULL)}, + + // One int value is converted to double. + {{Value(3)}, Value(3.0)}, + // One long value is converted to double. + {{Value(-4LL)}, Value(-4.0)}, + // One double value. + {{Value(22.6)}, Value(22.6)}, + + // Averaging two ints. + {{Value(10), Value(11)}, Value(10.5)}, + // Averaging two longs. + {{Value(10LL), Value(11LL)}, Value(10.5)}, + // Averaging two doubles. + {{Value(10.0), Value(11.0)}, Value(10.5)}, + + // The average of an int and a double is a double. + {{Value(10), Value(11.0)}, Value(10.5)}, + // The average of a long and a double is a double. + {{Value(5LL), Value(1.0)}, Value(3.0)}, + // The average of an int and a long is a double. + {{Value(5), Value(3LL)}, Value(4.0)}, + // Averaging an int, long, and double. + {{Value(1), Value(2LL), Value(6.0)}, Value(3.0)}, + + // Unlike $sum, two ints do not overflow in the 'total' portion of the average. 
+ {{Value(numeric_limits::max()), Value(numeric_limits::max())}, + Value(static_cast(numeric_limits::max()))}, + // Two longs do overflow in the 'total' portion of the average. + {{Value(numeric_limits::max()), Value(numeric_limits::max())}, + Value(static_cast(numeric_limits::max()))}, + + // Averaging two decimals. + {{Value(Decimal128("-1234567890.1234567889")), + Value(Decimal128("-1234567890.1234567891"))}, + Value(Decimal128("-1234567890.1234567890"))}, + + // Averaging two longs and a decimal results in an accurate decimal result. + {{Value(1234567890123456788LL), + Value(1234567890123456789LL), + Value(Decimal128("1234567890123456790.037037036703702"))}, + Value(Decimal128("1234567890123456789.012345678901234"))}, + + // Averaging a double and a decimal + {{Value(1.0E22), Value(Decimal128("9999999999999999999999.9999999999"))}, + Value(Decimal128("9999999999999999999999.99999999995"))}, + }); } TEST(Accumulators, First) { @@ -249,12 +266,12 @@ TEST(Accumulators, Sum) { // An int and a double do not trigger an int overflow. {{Value(numeric_limits::max()), Value(1.0)}, Value(static_cast(numeric_limits::max()) + 1.0)}, - // An int and a long overflow. + // An int and a long overflow into a double. {{Value(1), Value(numeric_limits::max())}, - Value(numeric_limits::min())}, - // Two longs overflow. + Value(-static_cast(numeric_limits::min()))}, + // Two longs overflow into a double. {{Value(numeric_limits::max()), Value(numeric_limits::max())}, - Value(-2LL)}, + Value(static_cast(numeric_limits::max()) * 2)}, // A long and a double do not trigger a long overflow. 
{{Value(numeric_limits::max()), Value(1.0)}, Value(numeric_limits::max() + 1.0)}, diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index 9c68924d476..1c976df8f42 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -41,8 +41,10 @@ #include "mongo/db/pipeline/expression_context.h" #include "mongo/db/pipeline/value.h" #include "mongo/platform/bits.h" +#include "mongo/platform/decimal128.h" #include "mongo/util/mongoutils/str.h" #include "mongo/util/string_map.h" +#include "mongo/util/summation.h" namespace mongo { using Parser = Expression::Parser; @@ -414,6 +416,8 @@ Value ExpressionAbs::evaluateNumericArg(const Value& numericArg) const { BSONType type = numericArg.getType(); if (type == NumberDouble) { return Value(std::abs(numericArg.getDouble())); + } else if (type == NumberDecimal) { + return Value(numericArg.getDecimal().toAbs()); } else { long long num = numericArg.getLong(); uassert(28680, @@ -432,14 +436,12 @@ const char* ExpressionAbs::getOpName() const { /* ------------------------- ExpressionAdd ----------------------------- */ Value ExpressionAdd::evaluateInternal(Variables* vars) const { - /* - We'll try to return the narrowest possible result value. To do that - without creating intermediate Values, do the arithmetic for double - and integral types in parallel, tracking the current narrowest - type. - */ - double doubleTotal = 0; - long long longTotal = 0; + // We'll try to return the narrowest possible result value while avoiding overflow, loss + // of precision due to intermediate rounding or implicit use of decimal types. To do that, + // compute a compensated sum for non-decimal values and a separate decimal sum for decimal + // values, and track the current narrowest type. 
+ DoubleDoubleSummation nonDecimalTotal; + Decimal128 decimalTotal; BSONType totalType = NumberInt; bool haveDate = false; @@ -447,40 +449,64 @@ Value ExpressionAdd::evaluateInternal(Variables* vars) const { for (size_t i = 0; i < n; ++i) { Value val = vpOperand[i]->evaluateInternal(vars); - if (val.numeric()) { - totalType = Value::getWidestNumeric(totalType, val.getType()); - - doubleTotal += val.coerceToDouble(); - longTotal += val.coerceToLong(); - } else if (val.getType() == Date) { - uassert(16612, "only one date allowed in an $add expression", !haveDate); - haveDate = true; - - // We don't manipulate totalType here. - - longTotal += val.getDate(); - doubleTotal += val.getDate(); - } else if (val.nullish()) { - return Value(BSONNULL); - } else { - uasserted(16554, - str::stream() << "$add only supports numeric or date types, not " - << typeName(val.getType())); + switch (val.getType()) { + case NumberDecimal: + decimalTotal = decimalTotal.add(val.getDecimal()); + totalType = NumberDecimal; + break; + case NumberDouble: + nonDecimalTotal.addDouble(val.getDouble()); + if (totalType != NumberDecimal) + totalType = NumberDouble; + break; + case NumberLong: + nonDecimalTotal.addLong(val.getLong()); + if (totalType == NumberInt) + totalType = NumberLong; + break; + case NumberInt: + nonDecimalTotal.addDouble(val.getInt()); + break; + case Date: + uassert(16612, "only one date allowed in an $add expression", !haveDate); + haveDate = true; + nonDecimalTotal.addLong(val.getDate()); + break; + default: + uassert(16554, + str::stream() << "$add only supports numeric or date types, not " + << typeName(val.getType()), + val.nullish()); + return Value(BSONNULL); } } if (haveDate) { - if (totalType == NumberDouble) - longTotal = static_cast(doubleTotal); + int64_t longTotal; + if (totalType == NumberDecimal) { + longTotal = decimalTotal.add(nonDecimalTotal.getDecimal()).toLong(); + } else { + uassert(ErrorCodes::Overflow, "date overflow in $add", 
nonDecimalTotal.fitsLong()); + longTotal = nonDecimalTotal.getLong(); + } return Value(Date_t::fromMillisSinceEpoch(longTotal)); - } else if (totalType == NumberLong) { - return Value(longTotal); - } else if (totalType == NumberDouble) { - return Value(doubleTotal); - } else if (totalType == NumberInt) { - return Value::createIntOrLong(longTotal); - } else { - massert(16417, "$add resulted in a non-numeric type", false); + } + switch (totalType) { + case NumberDecimal: + return Value(decimalTotal.add(nonDecimalTotal.getDecimal())); + case NumberLong: + dassert(nonDecimalTotal.isInteger()); + if (nonDecimalTotal.fitsLong()) + return Value(nonDecimalTotal.getLong()); + // Fallthrough. + case NumberInt: + if (nonDecimalTotal.fitsLong()) + return Value::createIntOrLong(nonDecimalTotal.getLong()); + // Fallthrough. + case NumberDouble: + return Value(nonDecimalTotal.getDouble()); + default: + massert(16417, "$add resulted in a non-numeric type", false); } } @@ -677,8 +703,16 @@ const char* ExpressionArrayElemAt::getOpName() const { Value ExpressionCeil::evaluateNumericArg(const Value& numericArg) const { // There's no point in taking the ceiling of integers or longs, it will have no effect. - return numericArg.getType() == NumberDouble ? Value(std::ceil(numericArg.getDouble())) - : numericArg; + switch (numericArg.getType()) { + case NumberDouble: + return Value(std::ceil(numericArg.getDouble())); + case NumberDecimal: + // Round toward the nearest decimal with a zero exponent in the positive direction. 
+ return Value(numericArg.getDecimal().quantize(Decimal128::kNormalizedZero, + Decimal128::kRoundTowardPositive)); + default: + return numericArg; + } } REGISTER_EXPRESSION(ceil, ExpressionCeil::parse); @@ -1203,10 +1237,20 @@ Value ExpressionDivide::evaluateInternal(Variables* vars) const { Value lhs = vpOperand[0]->evaluateInternal(vars); Value rhs = vpOperand[1]->evaluateInternal(vars); + auto assertNonZero = [](bool nonZero) { uassert(16608, "can't $divide by zero", nonZero); }; + if (lhs.numeric() && rhs.numeric()) { + // If, and only if, either side is decimal, return decimal. + if (lhs.getType() == NumberDecimal || rhs.getType() == NumberDecimal) { + Decimal128 numer = lhs.coerceToDecimal(); + Decimal128 denom = rhs.coerceToDecimal(); + assertNonZero(!denom.isZero()); + return Value(numer.divide(denom)); + } + double numer = lhs.coerceToDouble(); double denom = rhs.coerceToDouble(); - uassert(16608, "can't $divide by zero", denom != 0); + assertNonZero(denom != 0.0); return Value(numer / denom); } else if (lhs.nullish() || rhs.nullish()) { @@ -1228,7 +1272,10 @@ const char* ExpressionDivide::getOpName() const { /* ----------------------- ExpressionExp ---------------------------- */ Value ExpressionExp::evaluateNumericArg(const Value& numericArg) const { - // exp() always returns a double since e is a double. + // $exp always returns either a double or a decimal number, as e is irrational. + if (numericArg.getType() == NumberDecimal) + return Value(numericArg.coerceToDecimal().exponential()); + return Value(exp(numericArg.coerceToDouble())); } @@ -1740,8 +1787,16 @@ void ExpressionFilter::addDependencies(DepsTracker* deps, vector* path) Value ExpressionFloor::evaluateNumericArg(const Value& numericArg) const { // There's no point in taking the floor of integers or longs, it will have no effect. - return numericArg.getType() == NumberDouble ? 
Value(std::floor(numericArg.getDouble())) - : numericArg; + switch (numericArg.getType()) { + case NumberDouble: + return Value(std::floor(numericArg.getDouble())); + case NumberDecimal: + // Round toward the nearest decimal with a zero exponent in the negative direction. + return Value(numericArg.getDecimal().quantize(Decimal128::kNormalizedZero, + Decimal128::kRoundTowardNegative)); + default: + return numericArg; + } } REGISTER_EXPRESSION(floor, ExpressionFloor::parse); @@ -2028,10 +2083,20 @@ Value ExpressionMod::evaluateInternal(Variables* vars) const { BSONType rightType = rhs.getType(); if (lhs.numeric() && rhs.numeric()) { + auto assertNonZero = [](bool isZero) { uassert(16610, "can't $mod by zero", !isZero); }; + + // If either side is decimal, perform the operation in decimal. + if (leftType == NumberDecimal || rightType == NumberDecimal) { + Decimal128 left = lhs.coerceToDecimal(); + Decimal128 right = rhs.coerceToDecimal(); + assertNonZero(right.isZero()); + + return Value(left.modulo(right)); + } + // ensure we aren't modding by 0 double right = rhs.coerceToDouble(); - - uassert(16610, "can't $mod by 0", right != 0); + assertNonZero(right == 0); if (leftType == NumberDouble || (rightType == NumberDouble && !rhs.integral())) { // Need to do fmod. Integer-valued double case is handled below. @@ -2088,6 +2153,8 @@ Value ExpressionMultiply::evaluateInternal(Variables* vars) const { */ double doubleProduct = 1; long long longProduct = 1; + Decimal128 decimalProduct; // This will be initialized on encountering the first decimal. 
+ BSONType productType = NumberInt; const size_t n = vpOperand.size(); @@ -2095,10 +2162,23 @@ Value ExpressionMultiply::evaluateInternal(Variables* vars) const { Value val = vpOperand[i]->evaluateInternal(vars); if (val.numeric()) { + BSONType oldProductType = productType; productType = Value::getWidestNumeric(productType, val.getType()); - - doubleProduct *= val.coerceToDouble(); - longProduct *= val.coerceToLong(); + if (productType == NumberDecimal) { + // On finding the first decimal, convert the partial product to decimal. + if (oldProductType != NumberDecimal) { + decimalProduct = oldProductType == NumberDouble + ? Decimal128(doubleProduct, Decimal128::kRoundTo15Digits) + : Decimal128(static_cast(longProduct)); + } + decimalProduct = decimalProduct.multiply(val.coerceToDecimal()); + } else { + doubleProduct *= val.coerceToDouble(); + if (mongoSignedMultiplyOverflow64(longProduct, val.coerceToLong(), &longProduct)) { + // The 'longProduct' would have overflowed, so we're abandoning it. + productType = NumberDouble; + } + } } else if (val.nullish()) { return Value(BSONNULL); } else { @@ -2114,6 +2194,8 @@ Value ExpressionMultiply::evaluateInternal(Variables* vars) const { return Value(longProduct); else if (productType == NumberInt) return Value::createIntOrLong(longProduct); + else if (productType == NumberDecimal) + return Value(decimalProduct); else massert(16418, "$multiply resulted in a non-numeric type", false); } @@ -2390,6 +2472,12 @@ const char* ExpressionIndexOfCP::getOpName() const { /* ----------------------- ExpressionLn ---------------------------- */ Value ExpressionLn::evaluateNumericArg(const Value& numericArg) const { + if (numericArg.getType() == NumberDecimal) { + Decimal128 argDecimal = numericArg.getDecimal(); + if (argDecimal.isGreater(Decimal128::kNormalizedZero)) + return Value(argDecimal.logarithm()); + // Fall through for error case. 
+ } double argDouble = numericArg.coerceToDouble(); uassert(28766, str::stream() << "$ln's argument must be a positive number, but is " << argDouble, @@ -2417,6 +2505,18 @@ Value ExpressionLog::evaluateInternal(Variables* vars) const { str::stream() << "$log's base must be numeric, not " << typeName(baseVal.getType()), baseVal.numeric()); + if (argVal.getType() == NumberDecimal || baseVal.getType() == NumberDecimal) { + Decimal128 argDecimal = argVal.coerceToDecimal(); + Decimal128 baseDecimal = baseVal.coerceToDecimal(); + + if (argDecimal.isGreater(Decimal128::kNormalizedZero) && + baseDecimal.isNotEqual(Decimal128(1)) && + baseDecimal.isGreater(Decimal128::kNormalizedZero)) { + return Value(argDecimal.logarithm(baseDecimal)); + } + // Fall through for error cases. + } + double argDouble = argVal.coerceToDouble(); double baseDouble = baseVal.coerceToDouble(); uassert(28758, @@ -2437,6 +2537,13 @@ const char* ExpressionLog::getOpName() const { /* ----------------------- ExpressionLog10 ---------------------------- */ Value ExpressionLog10::evaluateNumericArg(const Value& numericArg) const { + if (numericArg.getType() == NumberDecimal) { + Decimal128 argDecimal = numericArg.getDecimal(); + if (argDecimal.isGreater(Decimal128::kNormalizedZero)) + return Value(argDecimal.logarithm(Decimal128(10))); + // Fall through for error case. + } + double argDouble = numericArg.coerceToDouble(); uassert(28761, str::stream() << "$log10's argument must be a positive number, but is " << argDouble, @@ -2672,15 +2779,24 @@ Value ExpressionPow::evaluateInternal(Variables* vars) const { str::stream() << "$pow's exponent must be numeric, not " << typeName(expType), expVal.numeric()); + auto checkNonZeroAndNeg = [](bool isZeroAndNeg) { + uassert(28764, "$pow cannot take a base of 0 and a negative exponent", !isZeroAndNeg); + }; + + // If either argument is decimal, return a decimal. 
+ if (baseType == NumberDecimal || expType == NumberDecimal) { + Decimal128 baseDecimal = baseVal.coerceToDecimal(); + Decimal128 expDecimal = expVal.coerceToDecimal(); + checkNonZeroAndNeg(baseDecimal.isZero() && expDecimal.isNegative()); + return Value(baseDecimal.power(expDecimal)); + } + // pow() will cast args to doubles. double baseDouble = baseVal.coerceToDouble(); double expDouble = expVal.coerceToDouble(); + checkNonZeroAndNeg(baseDouble == 0 && expDouble < 0); - uassert(28764, - "$pow cannot take a base of 0 and a negative exponent", - !(baseDouble == 0 && expDouble < 0)); - - // If either number is a double, return a double. + // If either argument is a double, return a double. if (baseType == NumberDouble || expType == NumberDouble) { return Value(std::pow(baseDouble, expDouble)); } @@ -3414,10 +3530,17 @@ const char* ExpressionSplit::getOpName() const { /* ----------------------- ExpressionSqrt ---------------------------- */ Value ExpressionSqrt::evaluateNumericArg(const Value& numericArg) const { + auto checkArg = [](bool nonNegative) { + uassert(28714, "$sqrt's argument must be greater than or equal to 0", nonNegative); + }; + + if (numericArg.getType() == NumberDecimal) { + Decimal128 argDec = numericArg.getDecimal(); + checkArg(!argDec.isLess(Decimal128::kNormalizedZero)); // NaN returns NaN without error + return Value(argDec.squareRoot()); + } double argDouble = numericArg.coerceToDouble(); - uassert(28714, - "$sqrt's argument must be greater than or equal to 0", - argDouble >= 0 || std::isnan(argDouble)); + checkArg(!(argDouble < 0)); // NaN returns NaN without error return Value(sqrt(argDouble)); } @@ -3638,7 +3761,11 @@ Value ExpressionSubtract::evaluateInternal(Variables* vars) const { BSONType diffType = Value::getWidestNumeric(rhs.getType(), lhs.getType()); - if (diffType == NumberDouble) { + if (diffType == NumberDecimal) { + Decimal128 right = rhs.coerceToDecimal(); + Decimal128 left = lhs.coerceToDecimal(); + return
Value(left.subtract(right)); + } else if (diffType == NumberDouble) { double right = rhs.coerceToDouble(); double left = lhs.coerceToDouble(); return Value(left - right); @@ -3838,8 +3965,15 @@ const char* ExpressionToUpper::getOpName() const { Value ExpressionTrunc::evaluateNumericArg(const Value& numericArg) const { // There's no point in truncating integers or longs, it will have no effect. - return numericArg.getType() == NumberDouble ? Value(std::trunc(numericArg.getDouble())) - : numericArg; + switch (numericArg.getType()) { + case NumberDecimal: + return Value(numericArg.getDecimal().quantize(Decimal128::kNormalizedZero, + Decimal128::kRoundTowardZero)); + case NumberDouble: + return Value(std::trunc(numericArg.getDouble())); + default: + return numericArg; + } } REGISTER_EXPRESSION(trunc, ExpressionTrunc::parse); diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp index 2487bfc18f1..a72d977ee27 100644 --- a/src/mongo/db/pipeline/expression_test.cpp +++ b/src/mongo/db/pipeline/expression_test.cpp @@ -693,7 +693,7 @@ TEST_F(ExpressionCeilTest, LongArg) { Value(numeric_limits::max())); } -TEST_F(ExpressionCeilTest, FloatArg) { +TEST_F(ExpressionCeilTest, DoubleArg) { assertEvaluates(Value(2.0), Value(2.0)); assertEvaluates(Value(-2.0), Value(-2.0)); assertEvaluates(Value(0.9), Value(1.0)); @@ -709,6 +709,20 @@ TEST_F(ExpressionCeilTest, FloatArg) { assertEvaluates(Value(smallerThanLong), Value(smallerThanLong)); } +TEST_F(ExpressionCeilTest, DecimalArg) { + assertEvaluates(Value(Decimal128("2")), Value(Decimal128("2.0"))); + assertEvaluates(Value(Decimal128("-2")), Value(Decimal128("-2.0"))); + assertEvaluates(Value(Decimal128("0.9")), Value(Decimal128("1.0"))); + assertEvaluates(Value(Decimal128("0.1")), Value(Decimal128("1.0"))); + assertEvaluates(Value(Decimal128("-1.2")), Value(Decimal128("-1.0"))); + assertEvaluates(Value(Decimal128("-1.7")), Value(Decimal128("-1.0"))); + 
assertEvaluates(Value(Decimal128("1234567889.000000000000000000000001")), + Value(Decimal128("1234567890"))); + assertEvaluates(Value(Decimal128("-99999999999999999999999999999.99")), + Value(Decimal128("-99999999999999999999999999999.00"))); + assertEvaluates(Value(Decimal128("3.4E-6000")), Value(Decimal128("1"))); +} + TEST_F(ExpressionCeilTest, NullArg) { assertEvaluates(Value(BSONNULL), Value(BSONNULL)); } @@ -737,7 +751,7 @@ TEST_F(ExpressionFloorTest, LongArg) { Value(numeric_limits::max())); } -TEST_F(ExpressionFloorTest, FloatArg) { +TEST_F(ExpressionFloorTest, DoubleArg) { assertEvaluates(Value(2.0), Value(2.0)); assertEvaluates(Value(-2.0), Value(-2.0)); assertEvaluates(Value(0.9), Value(0.0)); @@ -753,6 +767,20 @@ TEST_F(ExpressionFloorTest, FloatArg) { assertEvaluates(Value(smallerThanLong), Value(smallerThanLong)); } +TEST_F(ExpressionFloorTest, DecimalArg) { + assertEvaluates(Value(Decimal128("2")), Value(Decimal128("2.0"))); + assertEvaluates(Value(Decimal128("-2")), Value(Decimal128("-2.0"))); + assertEvaluates(Value(Decimal128("0.9")), Value(Decimal128("0.0"))); + assertEvaluates(Value(Decimal128("0.1")), Value(Decimal128("0.0"))); + assertEvaluates(Value(Decimal128("-1.2")), Value(Decimal128("-2.0"))); + assertEvaluates(Value(Decimal128("-1.7")), Value(Decimal128("-2.0"))); + assertEvaluates(Value(Decimal128("1234567890.000000000000000000000001")), + Value(Decimal128("1234567890"))); + assertEvaluates(Value(Decimal128("-99999999999999999999999999999.99")), + Value(Decimal128("-100000000000000000000000000000"))); + assertEvaluates(Value(Decimal128("3.4E-6000")), Value(Decimal128("0"))); +} + TEST_F(ExpressionFloorTest, NullArg) { assertEvaluates(Value(BSONNULL), Value(BSONNULL)); } @@ -838,7 +866,7 @@ TEST_F(ExpressionTruncTest, LongArg) { Value(numeric_limits::max())); } -TEST_F(ExpressionTruncTest, FloatArg) { +TEST_F(ExpressionTruncTest, DoubleArg) { assertEvaluates(Value(2.0), Value(2.0)); assertEvaluates(Value(-2.0), Value(-2.0)); 
assertEvaluates(Value(0.9), Value(0.0)); @@ -854,6 +882,20 @@ TEST_F(ExpressionTruncTest, FloatArg) { assertEvaluates(Value(smallerThanLong), Value(smallerThanLong)); } +TEST_F(ExpressionTruncTest, DecimalArg) { + assertEvaluates(Value(Decimal128("2")), Value(Decimal128("2.0"))); + assertEvaluates(Value(Decimal128("-2")), Value(Decimal128("-2.0"))); + assertEvaluates(Value(Decimal128("0.9")), Value(Decimal128("0.0"))); + assertEvaluates(Value(Decimal128("0.1")), Value(Decimal128("0.0"))); + assertEvaluates(Value(Decimal128("-1.2")), Value(Decimal128("-1.0"))); + assertEvaluates(Value(Decimal128("-1.7")), Value(Decimal128("-1.0"))); + assertEvaluates(Value(Decimal128("123456789.9999999999999999999999999")), + Value(Decimal128("123456789"))); + assertEvaluates(Value(Decimal128("-99999999999999999999999999999.99")), + Value(Decimal128("-99999999999999999999999999999.00"))); + assertEvaluates(Value(Decimal128("3.4E-6000")), Value(Decimal128("0"))); +} + TEST_F(ExpressionTruncTest, NullArg) { assertEvaluates(Value(BSONNULL), Value(BSONNULL)); } @@ -1038,8 +1080,8 @@ class IntLong : public TwoOperandBase { } }; -/** Adding an int and a long overflows. */ -class IntLongOverflow : public TwoOperandBase { +/** Adding an int and a long produces a double. */ +class IntLongOverflowToDouble : public TwoOperandBase { BSONObj operand1() { return BSON("" << numeric_limits::max()); } @@ -1047,11 +1089,10 @@ class IntLongOverflow : public TwoOperandBase { return BSON("" << numeric_limits::max()); } BSONObj expectedResult() { - // Aggregation currently treats signed integers as overflowing like unsigned integers do. + // When the result cannot be represented in a NumberLong, a NumberDouble is returned. 
const auto im = numeric_limits::max(); const auto llm = numeric_limits::max(); - const auto result = static_cast(static_cast(im) + - static_cast(llm)); + double result = static_cast(im) + static_cast(llm); return BSON("" << result); } }; @@ -4896,7 +4937,7 @@ public: add(); add(); add(); - add(); + add(); add(); add(); add(); diff --git a/src/mongo/db/pipeline/value.cpp b/src/mongo/db/pipeline/value.cpp index 8a37c51067c..37b3b88afae 100644 --- a/src/mongo/db/pipeline/value.cpp +++ b/src/mongo/db/pipeline/value.cpp @@ -971,8 +971,6 @@ BSONType Value::getWidestNumeric(BSONType lType, BSONType rType) { return Undefined; } -// TODO: Add Decimal128 support to Value::integral() -// SERVER-19735 bool Value::integral() const { switch (getType()) { case NumberInt: @@ -984,6 +982,13 @@ bool Value::integral() const { return (_storage.doubleValue <= numeric_limits::max() && _storage.doubleValue >= numeric_limits::min() && _storage.doubleValue == static_cast(_storage.doubleValue)); + case NumberDecimal: { + // If we are able to convert the decimal to an int32_t without any rounding errors, + // then it is integral. + uint32_t signalingFlags = Decimal128::kNoFlag; + (void)_storage.getDecimal().toInt(&signalingFlags); + return signalingFlags == Decimal128::kNoFlag; + } default: return false; } diff --git a/src/mongo/db/pipeline/value.h b/src/mongo/db/pipeline/value.h index c6d4b90c0cd..5ac5c61180b 100644 --- a/src/mongo/db/pipeline/value.h +++ b/src/mongo/db/pipeline/value.h @@ -119,11 +119,9 @@ public: } /// true if type represents a number - // TODO: Add _storage.type == NumberDecimal - // SERVER-19735 bool numeric() const { return _storage.type == NumberDouble || _storage.type == NumberLong || - _storage.type == NumberInt; + _storage.type == NumberInt || _storage.type == NumberDecimal; } /** -- cgit v1.2.1