summaryrefslogtreecommitdiff
path: root/src/mongo/db/pipeline/expression.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/db/pipeline/expression.cpp')
-rw-r--r--src/mongo/db/pipeline/expression.cpp252
1 files changed, 193 insertions, 59 deletions
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp
index 9c68924d476..1c976df8f42 100644
--- a/src/mongo/db/pipeline/expression.cpp
+++ b/src/mongo/db/pipeline/expression.cpp
@@ -41,8 +41,10 @@
#include "mongo/db/pipeline/expression_context.h"
#include "mongo/db/pipeline/value.h"
#include "mongo/platform/bits.h"
+#include "mongo/platform/decimal128.h"
#include "mongo/util/mongoutils/str.h"
#include "mongo/util/string_map.h"
+#include "mongo/util/summation.h"
namespace mongo {
using Parser = Expression::Parser;
@@ -414,6 +416,8 @@ Value ExpressionAbs::evaluateNumericArg(const Value& numericArg) const {
BSONType type = numericArg.getType();
if (type == NumberDouble) {
return Value(std::abs(numericArg.getDouble()));
+ } else if (type == NumberDecimal) {
+ return Value(numericArg.getDecimal().toAbs());
} else {
long long num = numericArg.getLong();
uassert(28680,
@@ -432,14 +436,12 @@ const char* ExpressionAbs::getOpName() const {
/* ------------------------- ExpressionAdd ----------------------------- */
Value ExpressionAdd::evaluateInternal(Variables* vars) const {
- /*
- We'll try to return the narrowest possible result value. To do that
- without creating intermediate Values, do the arithmetic for double
- and integral types in parallel, tracking the current narrowest
- type.
- */
- double doubleTotal = 0;
- long long longTotal = 0;
+ // We'll try to return the narrowest possible result value while avoiding overflow, loss
+ // of precision due to intermediate rounding or implicit use of decimal types. To do that,
+ // compute a compensated sum for non-decimal values and a separate decimal sum for decimal
+ // values, and track the current narrowest type.
+ DoubleDoubleSummation nonDecimalTotal;
+ Decimal128 decimalTotal;
BSONType totalType = NumberInt;
bool haveDate = false;
@@ -447,40 +449,64 @@ Value ExpressionAdd::evaluateInternal(Variables* vars) const {
for (size_t i = 0; i < n; ++i) {
Value val = vpOperand[i]->evaluateInternal(vars);
- if (val.numeric()) {
- totalType = Value::getWidestNumeric(totalType, val.getType());
-
- doubleTotal += val.coerceToDouble();
- longTotal += val.coerceToLong();
- } else if (val.getType() == Date) {
- uassert(16612, "only one date allowed in an $add expression", !haveDate);
- haveDate = true;
-
- // We don't manipulate totalType here.
-
- longTotal += val.getDate();
- doubleTotal += val.getDate();
- } else if (val.nullish()) {
- return Value(BSONNULL);
- } else {
- uasserted(16554,
- str::stream() << "$add only supports numeric or date types, not "
- << typeName(val.getType()));
+ switch (val.getType()) {
+ case NumberDecimal:
+ decimalTotal = decimalTotal.add(val.getDecimal());
+ totalType = NumberDecimal;
+ break;
+ case NumberDouble:
+ nonDecimalTotal.addDouble(val.getDouble());
+ if (totalType != NumberDecimal)
+ totalType = NumberDouble;
+ break;
+ case NumberLong:
+ nonDecimalTotal.addLong(val.getLong());
+ if (totalType == NumberInt)
+ totalType = NumberLong;
+ break;
+ case NumberInt:
+ nonDecimalTotal.addDouble(val.getInt());
+ break;
+ case Date:
+ uassert(16612, "only one date allowed in an $add expression", !haveDate);
+ haveDate = true;
+ nonDecimalTotal.addLong(val.getDate());
+ break;
+ default:
+ uassert(16554,
+ str::stream() << "$add only supports numeric or date types, not "
+ << typeName(val.getType()),
+ val.nullish());
+ return Value(BSONNULL);
}
}
if (haveDate) {
- if (totalType == NumberDouble)
- longTotal = static_cast<long long>(doubleTotal);
+ int64_t longTotal;
+ if (totalType == NumberDecimal) {
+ longTotal = decimalTotal.add(nonDecimalTotal.getDecimal()).toLong();
+ } else {
+ uassert(ErrorCodes::Overflow, "date overflow in $add", nonDecimalTotal.fitsLong());
+ longTotal = nonDecimalTotal.getLong();
+ }
return Value(Date_t::fromMillisSinceEpoch(longTotal));
- } else if (totalType == NumberLong) {
- return Value(longTotal);
- } else if (totalType == NumberDouble) {
- return Value(doubleTotal);
- } else if (totalType == NumberInt) {
- return Value::createIntOrLong(longTotal);
- } else {
- massert(16417, "$add resulted in a non-numeric type", false);
+ }
+ switch (totalType) {
+ case NumberDecimal:
+ return Value(decimalTotal.add(nonDecimalTotal.getDecimal()));
+ case NumberLong:
+ dassert(nonDecimalTotal.isInteger());
+ if (nonDecimalTotal.fitsLong())
+ return Value(nonDecimalTotal.getLong());
+ // Fallthrough.
+ case NumberInt:
+ if (nonDecimalTotal.fitsLong())
+ return Value::createIntOrLong(nonDecimalTotal.getLong());
+ // Fallthrough.
+ case NumberDouble:
+ return Value(nonDecimalTotal.getDouble());
+ default:
+ massert(16417, "$add resulted in a non-numeric type", false);
}
}
@@ -677,8 +703,16 @@ const char* ExpressionArrayElemAt::getOpName() const {
Value ExpressionCeil::evaluateNumericArg(const Value& numericArg) const {
// There's no point in taking the ceiling of integers or longs, it will have no effect.
- return numericArg.getType() == NumberDouble ? Value(std::ceil(numericArg.getDouble()))
- : numericArg;
+ switch (numericArg.getType()) {
+ case NumberDouble:
+ return Value(std::ceil(numericArg.getDouble()));
+ case NumberDecimal:
+ // Round toward the nearest decimal with a zero exponent in the positive direction.
+ return Value(numericArg.getDecimal().quantize(Decimal128::kNormalizedZero,
+ Decimal128::kRoundTowardPositive));
+ default:
+ return numericArg;
+ }
}
REGISTER_EXPRESSION(ceil, ExpressionCeil::parse);
@@ -1203,10 +1237,20 @@ Value ExpressionDivide::evaluateInternal(Variables* vars) const {
Value lhs = vpOperand[0]->evaluateInternal(vars);
Value rhs = vpOperand[1]->evaluateInternal(vars);
+ auto assertNonZero = [](bool nonZero) { uassert(16608, "can't $divide by zero", nonZero); };
+
if (lhs.numeric() && rhs.numeric()) {
+ // If, and only if, either side is decimal, return decimal.
+ if (lhs.getType() == NumberDecimal || rhs.getType() == NumberDecimal) {
+ Decimal128 numer = lhs.coerceToDecimal();
+ Decimal128 denom = rhs.coerceToDecimal();
+ assertNonZero(!denom.isZero());
+ return Value(numer.divide(denom));
+ }
+
double numer = lhs.coerceToDouble();
double denom = rhs.coerceToDouble();
- uassert(16608, "can't $divide by zero", denom != 0);
+ assertNonZero(denom != 0.0);
return Value(numer / denom);
} else if (lhs.nullish() || rhs.nullish()) {
@@ -1228,7 +1272,10 @@ const char* ExpressionDivide::getOpName() const {
/* ----------------------- ExpressionExp ---------------------------- */
Value ExpressionExp::evaluateNumericArg(const Value& numericArg) const {
- // exp() always returns a double since e is a double.
+ // $exp always returns either a double or a decimal number, as e is irrational.
+ if (numericArg.getType() == NumberDecimal)
+ return Value(numericArg.coerceToDecimal().exponential());
+
return Value(exp(numericArg.coerceToDouble()));
}
@@ -1740,8 +1787,16 @@ void ExpressionFilter::addDependencies(DepsTracker* deps, vector<string>* path)
Value ExpressionFloor::evaluateNumericArg(const Value& numericArg) const {
// There's no point in taking the floor of integers or longs, it will have no effect.
- return numericArg.getType() == NumberDouble ? Value(std::floor(numericArg.getDouble()))
- : numericArg;
+ switch (numericArg.getType()) {
+ case NumberDouble:
+ return Value(std::floor(numericArg.getDouble()));
+ case NumberDecimal:
+ // Round toward the nearest decimal with a zero exponent in the negative direction.
+ return Value(numericArg.getDecimal().quantize(Decimal128::kNormalizedZero,
+ Decimal128::kRoundTowardNegative));
+ default:
+ return numericArg;
+ }
}
REGISTER_EXPRESSION(floor, ExpressionFloor::parse);
@@ -2028,10 +2083,20 @@ Value ExpressionMod::evaluateInternal(Variables* vars) const {
BSONType rightType = rhs.getType();
if (lhs.numeric() && rhs.numeric()) {
+ auto assertNonZero = [](bool isZero) { uassert(16610, "can't $mod by zero", !isZero); };
+
+ // If either side is decimal, perform the operation in decimal.
+ if (leftType == NumberDecimal || rightType == NumberDecimal) {
+ Decimal128 left = lhs.coerceToDecimal();
+ Decimal128 right = rhs.coerceToDecimal();
+ assertNonZero(right.isZero());
+
+ return Value(left.modulo(right));
+ }
+
// ensure we aren't modding by 0
double right = rhs.coerceToDouble();
-
- uassert(16610, "can't $mod by 0", right != 0);
+ assertNonZero(right == 0);
if (leftType == NumberDouble || (rightType == NumberDouble && !rhs.integral())) {
// Need to do fmod. Integer-valued double case is handled below.
@@ -2088,6 +2153,8 @@ Value ExpressionMultiply::evaluateInternal(Variables* vars) const {
*/
double doubleProduct = 1;
long long longProduct = 1;
+ Decimal128 decimalProduct; // This will be initialized on encountering the first decimal.
+
BSONType productType = NumberInt;
const size_t n = vpOperand.size();
@@ -2095,10 +2162,23 @@ Value ExpressionMultiply::evaluateInternal(Variables* vars) const {
Value val = vpOperand[i]->evaluateInternal(vars);
if (val.numeric()) {
+ BSONType oldProductType = productType;
productType = Value::getWidestNumeric(productType, val.getType());
-
- doubleProduct *= val.coerceToDouble();
- longProduct *= val.coerceToLong();
+ if (productType == NumberDecimal) {
+ // On finding the first decimal, convert the partial product to decimal.
+ if (oldProductType != NumberDecimal) {
+ decimalProduct = oldProductType == NumberDouble
+ ? Decimal128(doubleProduct, Decimal128::kRoundTo15Digits)
+ : Decimal128(static_cast<int64_t>(longProduct));
+ }
+ decimalProduct = decimalProduct.multiply(val.coerceToDecimal());
+ } else {
+ doubleProduct *= val.coerceToDouble();
+ if (mongoSignedMultiplyOverflow64(longProduct, val.coerceToLong(), &longProduct)) {
+ // The 'longProduct' would have overflowed, so we're abandoning it.
+ productType = NumberDouble;
+ }
+ }
} else if (val.nullish()) {
return Value(BSONNULL);
} else {
@@ -2114,6 +2194,8 @@ Value ExpressionMultiply::evaluateInternal(Variables* vars) const {
return Value(longProduct);
else if (productType == NumberInt)
return Value::createIntOrLong(longProduct);
+ else if (productType == NumberDecimal)
+ return Value(decimalProduct);
else
massert(16418, "$multiply resulted in a non-numeric type", false);
}
@@ -2390,6 +2472,12 @@ const char* ExpressionIndexOfCP::getOpName() const {
/* ----------------------- ExpressionLn ---------------------------- */
Value ExpressionLn::evaluateNumericArg(const Value& numericArg) const {
+ if (numericArg.getType() == NumberDecimal) {
+ Decimal128 argDecimal = numericArg.getDecimal();
+ if (argDecimal.isGreater(Decimal128::kNormalizedZero))
+ return Value(argDecimal.logarithm());
+ // Fall through for error case.
+ }
double argDouble = numericArg.coerceToDouble();
uassert(28766,
str::stream() << "$ln's argument must be a positive number, but is " << argDouble,
@@ -2417,6 +2505,18 @@ Value ExpressionLog::evaluateInternal(Variables* vars) const {
str::stream() << "$log's base must be numeric, not " << typeName(baseVal.getType()),
baseVal.numeric());
+ if (argVal.getType() == NumberDecimal || baseVal.getType() == NumberDecimal) {
+ Decimal128 argDecimal = argVal.coerceToDecimal();
+ Decimal128 baseDecimal = baseVal.coerceToDecimal();
+
+ if (argDecimal.isGreater(Decimal128::kNormalizedZero) &&
+ baseDecimal.isNotEqual(Decimal128(1)) &&
+ baseDecimal.isGreater(Decimal128::kNormalizedZero)) {
+ return Value(argDecimal.logarithm(baseDecimal));
+ }
+ // Fall through for error cases.
+ }
+
double argDouble = argVal.coerceToDouble();
double baseDouble = baseVal.coerceToDouble();
uassert(28758,
@@ -2437,6 +2537,13 @@ const char* ExpressionLog::getOpName() const {
/* ----------------------- ExpressionLog10 ---------------------------- */
Value ExpressionLog10::evaluateNumericArg(const Value& numericArg) const {
+ if (numericArg.getType() == NumberDecimal) {
+ Decimal128 argDecimal = numericArg.getDecimal();
+ if (argDecimal.isGreater(Decimal128::kNormalizedZero))
+ return Value(argDecimal.logarithm(Decimal128(10)));
+ // Fall through for error case.
+ }
+
double argDouble = numericArg.coerceToDouble();
uassert(28761,
str::stream() << "$log10's argument must be a positive number, but is " << argDouble,
@@ -2672,15 +2779,24 @@ Value ExpressionPow::evaluateInternal(Variables* vars) const {
str::stream() << "$pow's exponent must be numeric, not " << typeName(expType),
expVal.numeric());
+ auto checkNonZeroAndNeg = [](bool isZeroAndNeg) {
+ uassert(28764, "$pow cannot take a base of 0 and a negative exponent", !isZeroAndNeg);
+ };
+
+ // If either argument is decimal, return a decimal.
+ if (baseType == NumberDecimal || expType == NumberDecimal) {
+ Decimal128 baseDecimal = baseVal.coerceToDecimal();
+ Decimal128 expDecimal = expVal.coerceToDecimal();
+ checkNonZeroAndNeg(baseDecimal.isZero() && expDecimal.isNegative());
+ return Value(baseDecimal.power(expDecimal));
+ }
+
// pow() will cast args to doubles.
double baseDouble = baseVal.coerceToDouble();
double expDouble = expVal.coerceToDouble();
+ checkNonZeroAndNeg(baseDouble == 0 && expDouble < 0);
- uassert(28764,
- "$pow cannot take a base of 0 and a negative exponent",
- !(baseDouble == 0 && expDouble < 0));
-
- // If either number is a double, return a double.
+ // If either argument is a double, return a double.
if (baseType == NumberDouble || expType == NumberDouble) {
return Value(std::pow(baseDouble, expDouble));
}
@@ -3414,10 +3530,17 @@ const char* ExpressionSplit::getOpName() const {
/* ----------------------- ExpressionSqrt ---------------------------- */
Value ExpressionSqrt::evaluateNumericArg(const Value& numericArg) const {
+ auto checkArg = [](bool nonNegative) {
+ uassert(28714, "$sqrt's argument must be greater than or equal to 0", nonNegative);
+ };
+
+ if (numericArg.getType() == NumberDecimal) {
+ Decimal128 argDec = numericArg.getDecimal();
+ checkArg(!argDec.isLess(Decimal128::kNormalizedZero)); // NaN returns Nan without error
+ return Value(argDec.squareRoot());
+ }
double argDouble = numericArg.coerceToDouble();
- uassert(28714,
- "$sqrt's argument must be greater than or equal to 0",
- argDouble >= 0 || std::isnan(argDouble));
+ checkArg(!(argDouble < 0)); // NaN returns Nan without error
return Value(sqrt(argDouble));
}
@@ -3638,7 +3761,11 @@ Value ExpressionSubtract::evaluateInternal(Variables* vars) const {
BSONType diffType = Value::getWidestNumeric(rhs.getType(), lhs.getType());
- if (diffType == NumberDouble) {
+ if (diffType == NumberDecimal) {
+ Decimal128 right = rhs.coerceToDecimal();
+ Decimal128 left = lhs.coerceToDecimal();
+ return Value(left.subtract(right));
+ } else if (diffType == NumberDouble) {
double right = rhs.coerceToDouble();
double left = lhs.coerceToDouble();
return Value(left - right);
@@ -3838,8 +3965,15 @@ const char* ExpressionToUpper::getOpName() const {
Value ExpressionTrunc::evaluateNumericArg(const Value& numericArg) const {
// There's no point in truncating integers or longs, it will have no effect.
- return numericArg.getType() == NumberDouble ? Value(std::trunc(numericArg.getDouble()))
- : numericArg;
+ switch (numericArg.getType()) {
+ case NumberDecimal:
+ return Value(numericArg.getDecimal().quantize(Decimal128::kNormalizedZero,
+ Decimal128::kRoundTowardZero));
+ case NumberDouble:
+ return Value(std::trunc(numericArg.getDouble()));
+ default:
+ return numericArg;
+ }
}
REGISTER_EXPRESSION(trunc, ExpressionTrunc::parse);