summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/mongo/db/pipeline/SConscript1
-rw-r--r--src/mongo/db/pipeline/accumulator.h17
-rw-r--r--src/mongo/db/pipeline/accumulator_avg.cpp80
-rw-r--r--src/mongo/db/pipeline/accumulator_sum.cpp105
-rw-r--r--src/mongo/db/pipeline/accumulator_test.cpp89
-rw-r--r--src/mongo/db/pipeline/expression.cpp252
-rw-r--r--src/mongo/db/pipeline/expression_test.cpp59
-rw-r--r--src/mongo/db/pipeline/value.cpp9
-rw-r--r--src/mongo/db/pipeline/value.h4
9 files changed, 455 insertions, 161 deletions
diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript
index eef0b7ecada..8525401fb72 100644
--- a/src/mongo/db/pipeline/SConscript
+++ b/src/mongo/db/pipeline/SConscript
@@ -91,6 +91,7 @@ env.Library(
],
LIBDEPS=[
'document_value',
+ '$BUILD_DIR/mongo/util/summation',
'expression',
'field_path',
]
diff --git a/src/mongo/db/pipeline/accumulator.h b/src/mongo/db/pipeline/accumulator.h
index 7dd604b82b4..eb09459f942 100644
--- a/src/mongo/db/pipeline/accumulator.h
+++ b/src/mongo/db/pipeline/accumulator.h
@@ -38,6 +38,7 @@
#include "mongo/bson/bsontypes.h"
#include "mongo/db/pipeline/value.h"
#include "mongo/stdx/functional.h"
+#include "mongo/util/summation.h"
namespace mongo {
/**
@@ -194,9 +195,9 @@ public:
}
private:
- BSONType totalType;
- long long longTotal;
- double doubleTotal;
+ BSONType totalType = NumberInt;
+ DoubleDoubleSummation nonDecimalTotal;
+ Decimal128 decimalTotal;
};
@@ -268,7 +269,15 @@ public:
static boost::intrusive_ptr<Accumulator> create();
private:
- double _total;
+ /**
+ * The total of all values is partitioned between those that are decimals, and those that are
+ * not decimals, so the decimal total needs to add the non-decimal.
+ */
+ Decimal128 _getDecimalTotal() const;
+
+ bool _isDecimal;
+ DoubleDoubleSummation _nonDecimalTotal;
+ Decimal128 _decimalTotal;
long long _count;
};
diff --git a/src/mongo/db/pipeline/accumulator_avg.cpp b/src/mongo/db/pipeline/accumulator_avg.cpp
index ed11d81ecc0..f9faf8d359a 100644
--- a/src/mongo/db/pipeline/accumulator_avg.cpp
+++ b/src/mongo/db/pipeline/accumulator_avg.cpp
@@ -33,6 +33,7 @@
#include "mongo/db/pipeline/expression.h"
#include "mongo/db/pipeline/expression_context.h"
#include "mongo/db/pipeline/value.h"
+#include "mongo/platform/decimal128.h"
namespace mongo {
@@ -47,48 +48,85 @@ const char* AccumulatorAvg::getOpName() const {
namespace {
const char subTotalName[] = "subTotal";
+const char subTotalErrorName[] = "subTotalError"; // Used for extra precision
const char countName[] = "count";
-}
+} // namespace
void AccumulatorAvg::processInternal(const Value& input, bool merging) {
- if (!merging) {
- // non numeric types have no impact on average
- if (!input.numeric())
- return;
-
- _total += input.getDouble();
- _count += 1;
- } else {
- // We expect an object that contains both a subtotal and a count.
- // This is what getValue(true) produced below.
+ if (merging) {
+ // We expect an object that contains both a subtotal and a count. Additionally there may
+ // be an error value, that allows for additional precision.
+ // 'input' is what getValue(true) produced below.
verify(input.getType() == Object);
- _total += input[subTotalName].getDouble();
- _count += input[countName].getLong();
+ // We're recursively adding the subtotal to get the proper type treatment, but this only
+ // increments the count by one, so adjust the count afterwards. Similarly for 'error'.
+ processInternal(input[subTotalName], false);
+ _count += input[countName].getLong() - 1;
+ Value error = input[subTotalErrorName];
+ if (!error.missing()) {
+ processInternal(error, false);
+ _count--; // The error correction only adjusts the total, not the number of items.
+ }
+ return;
+ }
+
+ switch (input.getType()) {
+ case NumberDecimal:
+ _decimalTotal = _decimalTotal.add(input.getDecimal());
+ _isDecimal = true;
+ break;
+ case NumberLong:
+ // Avoid summation using double as that loses precision.
+ _nonDecimalTotal.addLong(input.getLong());
+ break;
+ case NumberInt:
+ case NumberDouble:
+ _nonDecimalTotal.addDouble(input.getDouble());
+ break;
+ default:
+ dassert(!input.numeric());
+ return;
}
+ _count++;
}
intrusive_ptr<Accumulator> AccumulatorAvg::create() {
return new AccumulatorAvg();
}
+Decimal128 AccumulatorAvg::_getDecimalTotal() const {
+ return _decimalTotal.add(_nonDecimalTotal.getDecimal());
+}
+
Value AccumulatorAvg::getValue(bool toBeMerged) const {
- if (!toBeMerged) {
- if (_count == 0)
- return Value(BSONNULL);
+ if (toBeMerged) {
+ if (_isDecimal)
+ return Value(Document{{subTotalName, _getDecimalTotal()}, {countName, _count}});
- return Value(_total / static_cast<double>(_count));
- } else {
- return Value(DOC(subTotalName << _total << countName << _count));
+ double total, error;
+ std::tie(total, error) = _nonDecimalTotal.getDoubleDouble();
+ return Value(
+ Document{{subTotalName, total}, {countName, _count}, {subTotalErrorName, error}});
}
+
+ if (_count == 0)
+ return Value(BSONNULL);
+
+ if (_isDecimal)
+ return Value(_getDecimalTotal().divide(Decimal128(static_cast<int64_t>(_count))));
+
+ return Value(_nonDecimalTotal.getDouble() / static_cast<double>(_count));
}
-AccumulatorAvg::AccumulatorAvg() : _total(0), _count(0) {
+AccumulatorAvg::AccumulatorAvg() : _isDecimal(false), _count(0) {
// This is a fixed size Accumulator so we never need to update this
_memUsageBytes = sizeof(*this);
}
void AccumulatorAvg::reset() {
- _total = 0;
+ _isDecimal = false;
+ _nonDecimalTotal = {};
+ _decimalTotal = {};
_count = 0;
}
}
diff --git a/src/mongo/db/pipeline/accumulator_sum.cpp b/src/mongo/db/pipeline/accumulator_sum.cpp
index c064fe52f04..1255c3fde8b 100644
--- a/src/mongo/db/pipeline/accumulator_sum.cpp
+++ b/src/mongo/db/pipeline/accumulator_sum.cpp
@@ -28,9 +28,13 @@
#include "mongo/platform/basic.h"
+#include <cmath>
+#include <limits>
+
#include "mongo/db/pipeline/accumulator.h"
#include "mongo/db/pipeline/expression.h"
#include "mongo/db/pipeline/value.h"
+#include "mongo/util/summation.h"
namespace mongo {
@@ -43,24 +47,38 @@ const char* AccumulatorSum::getOpName() const {
return "$sum";
}
+namespace {
+const char subTotalName[] = "subTotal";
+const char subTotalErrorName[] = "subTotalError"; // Used for extra precision.
+} // namespace
+
+
void AccumulatorSum::processInternal(const Value& input, bool merging) {
- // do nothing with non numeric types
- if (!input.numeric())
+ if (!input.numeric()) {
+ if (merging && input.getType() == Object) {
+ // Process merge document, see getValue() below.
+ nonDecimalTotal.addDouble(
+ input[subTotalName].getDouble()); // Sum without adjusting type.
+ processInternal(input[subTotalErrorName], false); // Sum adjusting for type of error.
+ }
return;
+ }
- // upgrade to the widest type required to hold the result
+ // Upgrade to the widest type required to hold the result.
totalType = Value::getWidestNumeric(totalType, input.getType());
-
- if (totalType == NumberInt || totalType == NumberLong) {
- long long v = input.coerceToLong();
- longTotal += v;
- doubleTotal += v;
- } else if (totalType == NumberDouble) {
- double v = input.coerceToDouble();
- doubleTotal += v;
- } else {
- // non numerics should have returned above so we should never get here
- verify(false);
+ switch (input.getType()) {
+ case NumberInt:
+ case NumberLong:
+ nonDecimalTotal.addLong(input.coerceToLong());
+ break;
+ case NumberDouble:
+ nonDecimalTotal.addDouble(input.getDouble());
+ break;
+ case NumberDecimal:
+ decimalTotal = decimalTotal.add(input.coerceToDecimal());
+ break;
+ default:
+ MONGO_UNREACHABLE;
}
}
@@ -69,25 +87,58 @@ intrusive_ptr<Accumulator> AccumulatorSum::create() {
}
Value AccumulatorSum::getValue(bool toBeMerged) const {
- if (totalType == NumberLong) {
- return Value(longTotal);
- } else if (totalType == NumberDouble) {
- return Value(doubleTotal);
- } else if (totalType == NumberInt) {
- return Value::createIntOrLong(longTotal);
- } else {
- massert(16000, "$sum resulted in a non-numeric type", false);
+ switch (totalType) {
+ case NumberInt:
+ if (nonDecimalTotal.fitsLong())
+ return Value::createIntOrLong(nonDecimalTotal.getLong());
+ // Fallthrough.
+ case NumberLong:
+ if (nonDecimalTotal.fitsLong())
+ return Value(nonDecimalTotal.getLong());
+ if (toBeMerged) {
+ // The value was too large for a NumberLong, so output a document with two values
+ // adding up to the desired total. Older MongoDB versions used to ignore signed
+ // integer overflow and cause undefined behavior, that in practice resulted in
+ // values that would wrap around modulo 2**64. Now an older mongos with a newer
+ // mongod will yield an error that $sum resulted in a non-numeric type, which is
+ // OK for this case. Output the error using the totalType, so in the future we can
+ // determine the correct totalType for the sum. For the error to exceed 2**63,
+ // more than 2**53 integers would have to be summed, which is impossible.
+ double total;
+ double error;
+ std::tie(total, error) = nonDecimalTotal.getDoubleDouble();
+ long long llerror = static_cast<long long>(error);
+ return Value(DOC(subTotalName << total << subTotalErrorName << llerror));
+ }
+ // Sum doesn't fit a NumberLong, so return a NumberDouble instead.
+ return Value(nonDecimalTotal.getDouble());
+
+ case NumberDouble:
+ return Value(nonDecimalTotal.getDouble());
+ case NumberDecimal: {
+ double sum, error;
+ std::tie(sum, error) = nonDecimalTotal.getDoubleDouble();
+ Decimal128 total; // zero
+ if (sum != 0) {
+ total = total.add(Decimal128(sum, Decimal128::kRoundTo34Digits));
+ total = total.add(Decimal128(error, Decimal128::kRoundTo34Digits));
+ }
+ total = total.add(decimalTotal);
+ return Value(total);
+ }
+ default:
+ MONGO_UNREACHABLE;
}
}
-AccumulatorSum::AccumulatorSum() : totalType(NumberInt), longTotal(0), doubleTotal(0) {
- // This is a fixed size Accumulator so we never need to update this
+AccumulatorSum::AccumulatorSum() {
+ // This is a fixed size Accumulator so we never need to update this.
_memUsageBytes = sizeof(*this);
}
void AccumulatorSum::reset() {
totalType = NumberInt;
- longTotal = 0;
- doubleTotal = 0;
-}
+ nonDecimalTotal = {};
+ decimalTotal = {};
}
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/accumulator_test.cpp b/src/mongo/db/pipeline/accumulator_test.cpp
index 928eb2e868e..e354d71d8f9 100644
--- a/src/mongo/db/pipeline/accumulator_test.cpp
+++ b/src/mongo/db/pipeline/accumulator_test.cpp
@@ -96,38 +96,55 @@ static void assertExpectedResults(
TEST(Accumulators, Avg) {
assertExpectedResults(
"$avg",
- {// No documents evaluated.
- {{}, Value(BSONNULL)},
-
- // One int value is converted to double.
- {{Value(3)}, Value(3.0)},
- // One long value is converted to double.
- {{Value(-4LL)}, Value(-4.0)},
- // One double value.
- {{Value(22.6)}, Value(22.6)},
-
- // Averaging two ints.
- {{Value(10), Value(11)}, Value(10.5)},
- // Averaging two longs.
- {{Value(10LL), Value(11LL)}, Value(10.5)},
- // Averaging two doubles.
- {{Value(10.0), Value(11.0)}, Value(10.5)},
-
- // The average of an int and a double is a double.
- {{Value(10), Value(11.0)}, Value(10.5)},
- // The average of a long and a double is a double.
- {{Value(5LL), Value(1.0)}, Value(3.0)},
- // The average of an int and a long is a double.
- {{Value(5), Value(3LL)}, Value(4.0)},
- // Averaging an int, long, and double.
- {{Value(1), Value(2LL), Value(6.0)}, Value(3.0)},
-
- // Unlike $sum, two ints do not overflow in the 'total' portion of the average.
- {{Value(numeric_limits<int>::max()), Value(numeric_limits<int>::max())},
- Value(static_cast<double>(numeric_limits<int>::max()))},
- // Two longs do overflow in the 'total' portion of the average.
- {{Value(numeric_limits<long long>::max()), Value(numeric_limits<long long>::max())},
- Value(static_cast<double>(numeric_limits<long long>::max()))}});
+ {
+ // No documents evaluated.
+ {{}, Value(BSONNULL)},
+
+ // One int value is converted to double.
+ {{Value(3)}, Value(3.0)},
+ // One long value is converted to double.
+ {{Value(-4LL)}, Value(-4.0)},
+ // One double value.
+ {{Value(22.6)}, Value(22.6)},
+
+ // Averaging two ints.
+ {{Value(10), Value(11)}, Value(10.5)},
+ // Averaging two longs.
+ {{Value(10LL), Value(11LL)}, Value(10.5)},
+ // Averaging two doubles.
+ {{Value(10.0), Value(11.0)}, Value(10.5)},
+
+ // The average of an int and a double is a double.
+ {{Value(10), Value(11.0)}, Value(10.5)},
+ // The average of a long and a double is a double.
+ {{Value(5LL), Value(1.0)}, Value(3.0)},
+ // The average of an int and a long is a double.
+ {{Value(5), Value(3LL)}, Value(4.0)},
+ // Averaging an int, long, and double.
+ {{Value(1), Value(2LL), Value(6.0)}, Value(3.0)},
+
+ // Unlike $sum, two ints do not overflow in the 'total' portion of the average.
+ {{Value(numeric_limits<int>::max()), Value(numeric_limits<int>::max())},
+ Value(static_cast<double>(numeric_limits<int>::max()))},
+ // Two longs do overflow in the 'total' portion of the average.
+ {{Value(numeric_limits<long long>::max()), Value(numeric_limits<long long>::max())},
+ Value(static_cast<double>(numeric_limits<long long>::max()))},
+
+ // Averaging two decimals.
+ {{Value(Decimal128("-1234567890.1234567889")),
+ Value(Decimal128("-1234567890.1234567891"))},
+ Value(Decimal128("-1234567890.1234567890"))},
+
+ // Averaging two longs and a decimal results in an accurate decimal result.
+ {{Value(1234567890123456788LL),
+ Value(1234567890123456789LL),
+ Value(Decimal128("1234567890123456790.037037036703702"))},
+ Value(Decimal128("1234567890123456789.012345678901234"))},
+
+ // Averaging a double and a decimal
+ {{Value(1.0E22), Value(Decimal128("9999999999999999999999.9999999999"))},
+ Value(Decimal128("9999999999999999999999.99999999995"))},
+ });
}
TEST(Accumulators, First) {
@@ -249,12 +266,12 @@ TEST(Accumulators, Sum) {
// An int and a double do not trigger an int overflow.
{{Value(numeric_limits<int>::max()), Value(1.0)},
Value(static_cast<long long>(numeric_limits<int>::max()) + 1.0)},
- // An int and a long overflow.
+ // An int and a long overflow into a double.
{{Value(1), Value(numeric_limits<long long>::max())},
- Value(numeric_limits<long long>::min())},
- // Two longs overflow.
+ Value(-static_cast<double>(numeric_limits<long long>::min()))},
+ // Two longs overflow into a double.
{{Value(numeric_limits<long long>::max()), Value(numeric_limits<long long>::max())},
- Value(-2LL)},
+ Value(static_cast<double>(numeric_limits<long long>::max()) * 2)},
// A long and a double do not trigger a long overflow.
{{Value(numeric_limits<long long>::max()), Value(1.0)},
Value(numeric_limits<long long>::max() + 1.0)},
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp
index 9c68924d476..1c976df8f42 100644
--- a/src/mongo/db/pipeline/expression.cpp
+++ b/src/mongo/db/pipeline/expression.cpp
@@ -41,8 +41,10 @@
#include "mongo/db/pipeline/expression_context.h"
#include "mongo/db/pipeline/value.h"
#include "mongo/platform/bits.h"
+#include "mongo/platform/decimal128.h"
#include "mongo/util/mongoutils/str.h"
#include "mongo/util/string_map.h"
+#include "mongo/util/summation.h"
namespace mongo {
using Parser = Expression::Parser;
@@ -414,6 +416,8 @@ Value ExpressionAbs::evaluateNumericArg(const Value& numericArg) const {
BSONType type = numericArg.getType();
if (type == NumberDouble) {
return Value(std::abs(numericArg.getDouble()));
+ } else if (type == NumberDecimal) {
+ return Value(numericArg.getDecimal().toAbs());
} else {
long long num = numericArg.getLong();
uassert(28680,
@@ -432,14 +436,12 @@ const char* ExpressionAbs::getOpName() const {
/* ------------------------- ExpressionAdd ----------------------------- */
Value ExpressionAdd::evaluateInternal(Variables* vars) const {
- /*
- We'll try to return the narrowest possible result value. To do that
- without creating intermediate Values, do the arithmetic for double
- and integral types in parallel, tracking the current narrowest
- type.
- */
- double doubleTotal = 0;
- long long longTotal = 0;
+ // We'll try to return the narrowest possible result value while avoiding overflow, loss
+ // of precision due to intermediate rounding or implicit use of decimal types. To do that,
+ // compute a compensated sum for non-decimal values and a separate decimal sum for decimal
+ // values, and track the current narrowest type.
+ DoubleDoubleSummation nonDecimalTotal;
+ Decimal128 decimalTotal;
BSONType totalType = NumberInt;
bool haveDate = false;
@@ -447,40 +449,64 @@ Value ExpressionAdd::evaluateInternal(Variables* vars) const {
for (size_t i = 0; i < n; ++i) {
Value val = vpOperand[i]->evaluateInternal(vars);
- if (val.numeric()) {
- totalType = Value::getWidestNumeric(totalType, val.getType());
-
- doubleTotal += val.coerceToDouble();
- longTotal += val.coerceToLong();
- } else if (val.getType() == Date) {
- uassert(16612, "only one date allowed in an $add expression", !haveDate);
- haveDate = true;
-
- // We don't manipulate totalType here.
-
- longTotal += val.getDate();
- doubleTotal += val.getDate();
- } else if (val.nullish()) {
- return Value(BSONNULL);
- } else {
- uasserted(16554,
- str::stream() << "$add only supports numeric or date types, not "
- << typeName(val.getType()));
+ switch (val.getType()) {
+ case NumberDecimal:
+ decimalTotal = decimalTotal.add(val.getDecimal());
+ totalType = NumberDecimal;
+ break;
+ case NumberDouble:
+ nonDecimalTotal.addDouble(val.getDouble());
+ if (totalType != NumberDecimal)
+ totalType = NumberDouble;
+ break;
+ case NumberLong:
+ nonDecimalTotal.addLong(val.getLong());
+ if (totalType == NumberInt)
+ totalType = NumberLong;
+ break;
+ case NumberInt:
+ nonDecimalTotal.addDouble(val.getInt());
+ break;
+ case Date:
+ uassert(16612, "only one date allowed in an $add expression", !haveDate);
+ haveDate = true;
+ nonDecimalTotal.addLong(val.getDate());
+ break;
+ default:
+ uassert(16554,
+ str::stream() << "$add only supports numeric or date types, not "
+ << typeName(val.getType()),
+ val.nullish());
+ return Value(BSONNULL);
}
}
if (haveDate) {
- if (totalType == NumberDouble)
- longTotal = static_cast<long long>(doubleTotal);
+ int64_t longTotal;
+ if (totalType == NumberDecimal) {
+ longTotal = decimalTotal.add(nonDecimalTotal.getDecimal()).toLong();
+ } else {
+ uassert(ErrorCodes::Overflow, "date overflow in $add", nonDecimalTotal.fitsLong());
+ longTotal = nonDecimalTotal.getLong();
+ }
return Value(Date_t::fromMillisSinceEpoch(longTotal));
- } else if (totalType == NumberLong) {
- return Value(longTotal);
- } else if (totalType == NumberDouble) {
- return Value(doubleTotal);
- } else if (totalType == NumberInt) {
- return Value::createIntOrLong(longTotal);
- } else {
- massert(16417, "$add resulted in a non-numeric type", false);
+ }
+ switch (totalType) {
+ case NumberDecimal:
+ return Value(decimalTotal.add(nonDecimalTotal.getDecimal()));
+ case NumberLong:
+ dassert(nonDecimalTotal.isInteger());
+ if (nonDecimalTotal.fitsLong())
+ return Value(nonDecimalTotal.getLong());
+ // Fallthrough.
+ case NumberInt:
+ if (nonDecimalTotal.fitsLong())
+ return Value::createIntOrLong(nonDecimalTotal.getLong());
+ // Fallthrough.
+ case NumberDouble:
+ return Value(nonDecimalTotal.getDouble());
+ default:
+ massert(16417, "$add resulted in a non-numeric type", false);
}
}
@@ -677,8 +703,16 @@ const char* ExpressionArrayElemAt::getOpName() const {
Value ExpressionCeil::evaluateNumericArg(const Value& numericArg) const {
// There's no point in taking the ceiling of integers or longs, it will have no effect.
- return numericArg.getType() == NumberDouble ? Value(std::ceil(numericArg.getDouble()))
- : numericArg;
+ switch (numericArg.getType()) {
+ case NumberDouble:
+ return Value(std::ceil(numericArg.getDouble()));
+ case NumberDecimal:
+ // Round toward the nearest decimal with a zero exponent in the positive direction.
+ return Value(numericArg.getDecimal().quantize(Decimal128::kNormalizedZero,
+ Decimal128::kRoundTowardPositive));
+ default:
+ return numericArg;
+ }
}
REGISTER_EXPRESSION(ceil, ExpressionCeil::parse);
@@ -1203,10 +1237,20 @@ Value ExpressionDivide::evaluateInternal(Variables* vars) const {
Value lhs = vpOperand[0]->evaluateInternal(vars);
Value rhs = vpOperand[1]->evaluateInternal(vars);
+ auto assertNonZero = [](bool nonZero) { uassert(16608, "can't $divide by zero", nonZero); };
+
if (lhs.numeric() && rhs.numeric()) {
+ // If, and only if, either side is decimal, return decimal.
+ if (lhs.getType() == NumberDecimal || rhs.getType() == NumberDecimal) {
+ Decimal128 numer = lhs.coerceToDecimal();
+ Decimal128 denom = rhs.coerceToDecimal();
+ assertNonZero(!denom.isZero());
+ return Value(numer.divide(denom));
+ }
+
double numer = lhs.coerceToDouble();
double denom = rhs.coerceToDouble();
- uassert(16608, "can't $divide by zero", denom != 0);
+ assertNonZero(denom != 0.0);
return Value(numer / denom);
} else if (lhs.nullish() || rhs.nullish()) {
@@ -1228,7 +1272,10 @@ const char* ExpressionDivide::getOpName() const {
/* ----------------------- ExpressionExp ---------------------------- */
Value ExpressionExp::evaluateNumericArg(const Value& numericArg) const {
- // exp() always returns a double since e is a double.
+ // $exp always returns either a double or a decimal number, as e is irrational.
+ if (numericArg.getType() == NumberDecimal)
+ return Value(numericArg.coerceToDecimal().exponential());
+
return Value(exp(numericArg.coerceToDouble()));
}
@@ -1740,8 +1787,16 @@ void ExpressionFilter::addDependencies(DepsTracker* deps, vector<string>* path)
Value ExpressionFloor::evaluateNumericArg(const Value& numericArg) const {
// There's no point in taking the floor of integers or longs, it will have no effect.
- return numericArg.getType() == NumberDouble ? Value(std::floor(numericArg.getDouble()))
- : numericArg;
+ switch (numericArg.getType()) {
+ case NumberDouble:
+ return Value(std::floor(numericArg.getDouble()));
+ case NumberDecimal:
+ // Round toward the nearest decimal with a zero exponent in the negative direction.
+ return Value(numericArg.getDecimal().quantize(Decimal128::kNormalizedZero,
+ Decimal128::kRoundTowardNegative));
+ default:
+ return numericArg;
+ }
}
REGISTER_EXPRESSION(floor, ExpressionFloor::parse);
@@ -2028,10 +2083,20 @@ Value ExpressionMod::evaluateInternal(Variables* vars) const {
BSONType rightType = rhs.getType();
if (lhs.numeric() && rhs.numeric()) {
+ auto assertNonZero = [](bool isZero) { uassert(16610, "can't $mod by zero", !isZero); };
+
+ // If either side is decimal, perform the operation in decimal.
+ if (leftType == NumberDecimal || rightType == NumberDecimal) {
+ Decimal128 left = lhs.coerceToDecimal();
+ Decimal128 right = rhs.coerceToDecimal();
+ assertNonZero(right.isZero());
+
+ return Value(left.modulo(right));
+ }
+
// ensure we aren't modding by 0
double right = rhs.coerceToDouble();
-
- uassert(16610, "can't $mod by 0", right != 0);
+ assertNonZero(right == 0);
if (leftType == NumberDouble || (rightType == NumberDouble && !rhs.integral())) {
// Need to do fmod. Integer-valued double case is handled below.
@@ -2088,6 +2153,8 @@ Value ExpressionMultiply::evaluateInternal(Variables* vars) const {
*/
double doubleProduct = 1;
long long longProduct = 1;
+ Decimal128 decimalProduct; // This will be initialized on encountering the first decimal.
+
BSONType productType = NumberInt;
const size_t n = vpOperand.size();
@@ -2095,10 +2162,23 @@ Value ExpressionMultiply::evaluateInternal(Variables* vars) const {
Value val = vpOperand[i]->evaluateInternal(vars);
if (val.numeric()) {
+ BSONType oldProductType = productType;
productType = Value::getWidestNumeric(productType, val.getType());
-
- doubleProduct *= val.coerceToDouble();
- longProduct *= val.coerceToLong();
+ if (productType == NumberDecimal) {
+ // On finding the first decimal, convert the partial product to decimal.
+ if (oldProductType != NumberDecimal) {
+ decimalProduct = oldProductType == NumberDouble
+ ? Decimal128(doubleProduct, Decimal128::kRoundTo15Digits)
+ : Decimal128(static_cast<int64_t>(longProduct));
+ }
+ decimalProduct = decimalProduct.multiply(val.coerceToDecimal());
+ } else {
+ doubleProduct *= val.coerceToDouble();
+ if (mongoSignedMultiplyOverflow64(longProduct, val.coerceToLong(), &longProduct)) {
+ // The 'longProduct' would have overflowed, so we're abandoning it.
+ productType = NumberDouble;
+ }
+ }
} else if (val.nullish()) {
return Value(BSONNULL);
} else {
@@ -2114,6 +2194,8 @@ Value ExpressionMultiply::evaluateInternal(Variables* vars) const {
return Value(longProduct);
else if (productType == NumberInt)
return Value::createIntOrLong(longProduct);
+ else if (productType == NumberDecimal)
+ return Value(decimalProduct);
else
massert(16418, "$multiply resulted in a non-numeric type", false);
}
@@ -2390,6 +2472,12 @@ const char* ExpressionIndexOfCP::getOpName() const {
/* ----------------------- ExpressionLn ---------------------------- */
Value ExpressionLn::evaluateNumericArg(const Value& numericArg) const {
+ if (numericArg.getType() == NumberDecimal) {
+ Decimal128 argDecimal = numericArg.getDecimal();
+ if (argDecimal.isGreater(Decimal128::kNormalizedZero))
+ return Value(argDecimal.logarithm());
+ // Fall through for error case.
+ }
double argDouble = numericArg.coerceToDouble();
uassert(28766,
str::stream() << "$ln's argument must be a positive number, but is " << argDouble,
@@ -2417,6 +2505,18 @@ Value ExpressionLog::evaluateInternal(Variables* vars) const {
str::stream() << "$log's base must be numeric, not " << typeName(baseVal.getType()),
baseVal.numeric());
+ if (argVal.getType() == NumberDecimal || baseVal.getType() == NumberDecimal) {
+ Decimal128 argDecimal = argVal.coerceToDecimal();
+ Decimal128 baseDecimal = baseVal.coerceToDecimal();
+
+ if (argDecimal.isGreater(Decimal128::kNormalizedZero) &&
+ baseDecimal.isNotEqual(Decimal128(1)) &&
+ baseDecimal.isGreater(Decimal128::kNormalizedZero)) {
+ return Value(argDecimal.logarithm(baseDecimal));
+ }
+ // Fall through for error cases.
+ }
+
double argDouble = argVal.coerceToDouble();
double baseDouble = baseVal.coerceToDouble();
uassert(28758,
@@ -2437,6 +2537,13 @@ const char* ExpressionLog::getOpName() const {
/* ----------------------- ExpressionLog10 ---------------------------- */
Value ExpressionLog10::evaluateNumericArg(const Value& numericArg) const {
+ if (numericArg.getType() == NumberDecimal) {
+ Decimal128 argDecimal = numericArg.getDecimal();
+ if (argDecimal.isGreater(Decimal128::kNormalizedZero))
+ return Value(argDecimal.logarithm(Decimal128(10)));
+ // Fall through for error case.
+ }
+
double argDouble = numericArg.coerceToDouble();
uassert(28761,
str::stream() << "$log10's argument must be a positive number, but is " << argDouble,
@@ -2672,15 +2779,24 @@ Value ExpressionPow::evaluateInternal(Variables* vars) const {
str::stream() << "$pow's exponent must be numeric, not " << typeName(expType),
expVal.numeric());
+ auto checkNonZeroAndNeg = [](bool isZeroAndNeg) {
+ uassert(28764, "$pow cannot take a base of 0 and a negative exponent", !isZeroAndNeg);
+ };
+
+ // If either argument is decimal, return a decimal.
+ if (baseType == NumberDecimal || expType == NumberDecimal) {
+ Decimal128 baseDecimal = baseVal.coerceToDecimal();
+ Decimal128 expDecimal = expVal.coerceToDecimal();
+ checkNonZeroAndNeg(baseDecimal.isZero() && expDecimal.isNegative());
+ return Value(baseDecimal.power(expDecimal));
+ }
+
// pow() will cast args to doubles.
double baseDouble = baseVal.coerceToDouble();
double expDouble = expVal.coerceToDouble();
+ checkNonZeroAndNeg(baseDouble == 0 && expDouble < 0);
- uassert(28764,
- "$pow cannot take a base of 0 and a negative exponent",
- !(baseDouble == 0 && expDouble < 0));
-
- // If either number is a double, return a double.
+ // If either argument is a double, return a double.
if (baseType == NumberDouble || expType == NumberDouble) {
return Value(std::pow(baseDouble, expDouble));
}
@@ -3414,10 +3530,17 @@ const char* ExpressionSplit::getOpName() const {
/* ----------------------- ExpressionSqrt ---------------------------- */
Value ExpressionSqrt::evaluateNumericArg(const Value& numericArg) const {
+ auto checkArg = [](bool nonNegative) {
+ uassert(28714, "$sqrt's argument must be greater than or equal to 0", nonNegative);
+ };
+
+ if (numericArg.getType() == NumberDecimal) {
+ Decimal128 argDec = numericArg.getDecimal();
+ checkArg(!argDec.isLess(Decimal128::kNormalizedZero)); // NaN returns Nan without error
+ return Value(argDec.squareRoot());
+ }
double argDouble = numericArg.coerceToDouble();
- uassert(28714,
- "$sqrt's argument must be greater than or equal to 0",
- argDouble >= 0 || std::isnan(argDouble));
+ checkArg(!(argDouble < 0)); // NaN returns Nan without error
return Value(sqrt(argDouble));
}
@@ -3638,7 +3761,11 @@ Value ExpressionSubtract::evaluateInternal(Variables* vars) const {
BSONType diffType = Value::getWidestNumeric(rhs.getType(), lhs.getType());
- if (diffType == NumberDouble) {
+ if (diffType == NumberDecimal) {
+ Decimal128 right = rhs.coerceToDecimal();
+ Decimal128 left = lhs.coerceToDecimal();
+ return Value(left.subtract(right));
+ } else if (diffType == NumberDouble) {
double right = rhs.coerceToDouble();
double left = lhs.coerceToDouble();
return Value(left - right);
@@ -3838,8 +3965,15 @@ const char* ExpressionToUpper::getOpName() const {
Value ExpressionTrunc::evaluateNumericArg(const Value& numericArg) const {
// There's no point in truncating integers or longs, it will have no effect.
- return numericArg.getType() == NumberDouble ? Value(std::trunc(numericArg.getDouble()))
- : numericArg;
+ switch (numericArg.getType()) {
+ case NumberDecimal:
+ return Value(numericArg.getDecimal().quantize(Decimal128::kNormalizedZero,
+ Decimal128::kRoundTowardZero));
+ case NumberDouble:
+ return Value(std::trunc(numericArg.getDouble()));
+ default:
+ return numericArg;
+ }
}
REGISTER_EXPRESSION(trunc, ExpressionTrunc::parse);
diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp
index 2487bfc18f1..a72d977ee27 100644
--- a/src/mongo/db/pipeline/expression_test.cpp
+++ b/src/mongo/db/pipeline/expression_test.cpp
@@ -693,7 +693,7 @@ TEST_F(ExpressionCeilTest, LongArg) {
Value(numeric_limits<long long>::max()));
}
-TEST_F(ExpressionCeilTest, FloatArg) {
+TEST_F(ExpressionCeilTest, DoubleArg) {
assertEvaluates(Value(2.0), Value(2.0));
assertEvaluates(Value(-2.0), Value(-2.0));
assertEvaluates(Value(0.9), Value(1.0));
@@ -709,6 +709,20 @@ TEST_F(ExpressionCeilTest, FloatArg) {
assertEvaluates(Value(smallerThanLong), Value(smallerThanLong));
}
+TEST_F(ExpressionCeilTest, DecimalArg) {
+ assertEvaluates(Value(Decimal128("2")), Value(Decimal128("2.0")));
+ assertEvaluates(Value(Decimal128("-2")), Value(Decimal128("-2.0")));
+ assertEvaluates(Value(Decimal128("0.9")), Value(Decimal128("1.0")));
+ assertEvaluates(Value(Decimal128("0.1")), Value(Decimal128("1.0")));
+ assertEvaluates(Value(Decimal128("-1.2")), Value(Decimal128("-1.0")));
+ assertEvaluates(Value(Decimal128("-1.7")), Value(Decimal128("-1.0")));
+ assertEvaluates(Value(Decimal128("1234567889.000000000000000000000001")),
+ Value(Decimal128("1234567890")));
+ assertEvaluates(Value(Decimal128("-99999999999999999999999999999.99")),
+ Value(Decimal128("-99999999999999999999999999999.00")));
+ assertEvaluates(Value(Decimal128("3.4E-6000")), Value(Decimal128("1")));
+}
+
TEST_F(ExpressionCeilTest, NullArg) {
assertEvaluates(Value(BSONNULL), Value(BSONNULL));
}
@@ -737,7 +751,7 @@ TEST_F(ExpressionFloorTest, LongArg) {
Value(numeric_limits<long long>::max()));
}
-TEST_F(ExpressionFloorTest, FloatArg) {
+TEST_F(ExpressionFloorTest, DoubleArg) {
assertEvaluates(Value(2.0), Value(2.0));
assertEvaluates(Value(-2.0), Value(-2.0));
assertEvaluates(Value(0.9), Value(0.0));
@@ -753,6 +767,20 @@ TEST_F(ExpressionFloorTest, FloatArg) {
assertEvaluates(Value(smallerThanLong), Value(smallerThanLong));
}
+TEST_F(ExpressionFloorTest, DecimalArg) {
+ assertEvaluates(Value(Decimal128("2")), Value(Decimal128("2.0")));
+ assertEvaluates(Value(Decimal128("-2")), Value(Decimal128("-2.0")));
+ assertEvaluates(Value(Decimal128("0.9")), Value(Decimal128("0.0")));
+ assertEvaluates(Value(Decimal128("0.1")), Value(Decimal128("0.0")));
+ assertEvaluates(Value(Decimal128("-1.2")), Value(Decimal128("-2.0")));
+ assertEvaluates(Value(Decimal128("-1.7")), Value(Decimal128("-2.0")));
+ assertEvaluates(Value(Decimal128("1234567890.000000000000000000000001")),
+ Value(Decimal128("1234567890")));
+ assertEvaluates(Value(Decimal128("-99999999999999999999999999999.99")),
+ Value(Decimal128("-100000000000000000000000000000")));
+ assertEvaluates(Value(Decimal128("3.4E-6000")), Value(Decimal128("0")));
+}
+
TEST_F(ExpressionFloorTest, NullArg) {
assertEvaluates(Value(BSONNULL), Value(BSONNULL));
}
@@ -838,7 +866,7 @@ TEST_F(ExpressionTruncTest, LongArg) {
Value(numeric_limits<long long>::max()));
}
-TEST_F(ExpressionTruncTest, FloatArg) {
+TEST_F(ExpressionTruncTest, DoubleArg) {
assertEvaluates(Value(2.0), Value(2.0));
assertEvaluates(Value(-2.0), Value(-2.0));
assertEvaluates(Value(0.9), Value(0.0));
@@ -854,6 +882,20 @@ TEST_F(ExpressionTruncTest, FloatArg) {
assertEvaluates(Value(smallerThanLong), Value(smallerThanLong));
}
+TEST_F(ExpressionTruncTest, DecimalArg) {
+ assertEvaluates(Value(Decimal128("2")), Value(Decimal128("2.0")));
+ assertEvaluates(Value(Decimal128("-2")), Value(Decimal128("-2.0")));
+ assertEvaluates(Value(Decimal128("0.9")), Value(Decimal128("0.0")));
+ assertEvaluates(Value(Decimal128("0.1")), Value(Decimal128("0.0")));
+ assertEvaluates(Value(Decimal128("-1.2")), Value(Decimal128("-1.0")));
+ assertEvaluates(Value(Decimal128("-1.7")), Value(Decimal128("-1.0")));
+ assertEvaluates(Value(Decimal128("123456789.9999999999999999999999999")),
+ Value(Decimal128("123456789")));
+ assertEvaluates(Value(Decimal128("-99999999999999999999999999999.99")),
+ Value(Decimal128("-99999999999999999999999999999.00")));
+ assertEvaluates(Value(Decimal128("3.4E-6000")), Value(Decimal128("0")));
+}
+
TEST_F(ExpressionTruncTest, NullArg) {
assertEvaluates(Value(BSONNULL), Value(BSONNULL));
}
@@ -1038,8 +1080,8 @@ class IntLong : public TwoOperandBase {
}
};
-/** Adding an int and a long overflows. */
-class IntLongOverflow : public TwoOperandBase {
+/** Adding an int and a long produces a double. */
+class IntLongOverflowToDouble : public TwoOperandBase {
BSONObj operand1() {
return BSON("" << numeric_limits<int>::max());
}
@@ -1047,11 +1089,10 @@ class IntLongOverflow : public TwoOperandBase {
return BSON("" << numeric_limits<long long>::max());
}
BSONObj expectedResult() {
- // Aggregation currently treats signed integers as overflowing like unsigned integers do.
+ // When the result cannot be represented in a NumberLong, a NumberDouble is returned.
const auto im = numeric_limits<int>::max();
const auto llm = numeric_limits<long long>::max();
- const auto result = static_cast<long long>(static_cast<unsigned int>(im) +
- static_cast<unsigned long long>(llm));
+ double result = static_cast<double>(im) + static_cast<double>(llm);
return BSON("" << result);
}
};
@@ -4896,7 +4937,7 @@ public:
add<Add::IntInt>();
add<Add::IntIntNoOverflow>();
add<Add::IntLong>();
- add<Add::IntLongOverflow>();
+ add<Add::IntLongOverflowToDouble>();
add<Add::IntDouble>();
add<Add::IntDate>();
add<Add::LongDouble>();
diff --git a/src/mongo/db/pipeline/value.cpp b/src/mongo/db/pipeline/value.cpp
index 8a37c51067c..37b3b88afae 100644
--- a/src/mongo/db/pipeline/value.cpp
+++ b/src/mongo/db/pipeline/value.cpp
@@ -971,8 +971,6 @@ BSONType Value::getWidestNumeric(BSONType lType, BSONType rType) {
return Undefined;
}
-// TODO: Add Decimal128 support to Value::integral()
-// SERVER-19735
bool Value::integral() const {
switch (getType()) {
case NumberInt:
@@ -984,6 +982,13 @@ bool Value::integral() const {
return (_storage.doubleValue <= numeric_limits<int>::max() &&
_storage.doubleValue >= numeric_limits<int>::min() &&
_storage.doubleValue == static_cast<int>(_storage.doubleValue));
+ case NumberDecimal: {
+ // If we are able to convert the decimal to an int32_t without an rounding errors,
+ // then it is integral.
+ uint32_t signalingFlags = Decimal128::kNoFlag;
+ (void)_storage.getDecimal().toInt(&signalingFlags);
+ return signalingFlags == Decimal128::kNoFlag;
+ }
default:
return false;
}
diff --git a/src/mongo/db/pipeline/value.h b/src/mongo/db/pipeline/value.h
index c6d4b90c0cd..5ac5c61180b 100644
--- a/src/mongo/db/pipeline/value.h
+++ b/src/mongo/db/pipeline/value.h
@@ -119,11 +119,9 @@ public:
}
/// true if type represents a number
- // TODO: Add _storage.type == NumberDecimal
- // SERVER-19735
bool numeric() const {
return _storage.type == NumberDouble || _storage.type == NumberLong ||
- _storage.type == NumberInt;
+ _storage.type == NumberInt || _storage.type == NumberDecimal;
}
/**