summaryrefslogtreecommitdiff
path: root/src/mongo/db/exec
diff options
context:
space:
mode:
authorBobby Morck <bobby.morck@mongodb.com>2021-09-29 11:04:22 -0400
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2021-10-15 21:30:02 +0000
commitdad32a29132b9427a9d742f4f4d2ecf3bc3d830f (patch)
treeb9c91e702baa42615a35f4dec160381086def687 /src/mongo/db/exec
parenta9e2c89cae1be4235dee8321e0ff5511566772c1 (diff)
downloadmongo-dad32a29132b9427a9d742f4f4d2ecf3bc3d830f.tar.gz
SERVER-57552 Adding support for stdDevPop and stdDevSamp in SBE
Diffstat (limited to 'src/mongo/db/exec')
-rw-r--r--src/mongo/db/exec/sbe/expressions/expression.cpp5
-rw-r--r--src/mongo/db/exec/sbe/vm/arith.cpp82
-rw-r--r--src/mongo/db/exec/sbe/vm/vm.cpp51
-rw-r--r--src/mongo/db/exec/sbe/vm/vm.h35
4 files changed, 172 insertions, 1 deletions
diff --git a/src/mongo/db/exec/sbe/expressions/expression.cpp b/src/mongo/db/exec/sbe/expressions/expression.cpp
index 8e1a8785ec0..e1dd5d7b205 100644
--- a/src/mongo/db/exec/sbe/expressions/expression.cpp
+++ b/src/mongo/db/exec/sbe/expressions/expression.cpp
@@ -412,6 +412,11 @@ static stdx::unordered_map<std::string, BuiltinFn> kBuiltinFunctions = {
BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::aggDoubleDoubleSum, true}},
{"doubleDoubleSumFinalize",
BuiltinFn{[](size_t n) { return n > 0; }, vm::Builtin::doubleDoubleSumFinalize, false}},
+ {"aggStdDev", BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::aggStdDev, true}},
+ {"stdDevPopFinalize",
+ BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::stdDevPopFinalize, false}},
+ {"stdDevSampFinalize",
+ BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::stdDevSampFinalize, false}},
{"bitTestZero", BuiltinFn{[](size_t n) { return n == 2; }, vm::Builtin::bitTestZero, false}},
{"bitTestMask", BuiltinFn{[](size_t n) { return n == 2; }, vm::Builtin::bitTestMask, false}},
{"bitTestPosition",
diff --git a/src/mongo/db/exec/sbe/vm/arith.cpp b/src/mongo/db/exec/sbe/vm/arith.cpp
index ed9d0d00e76..6da30783aab 100644
--- a/src/mongo/db/exec/sbe/vm/arith.cpp
+++ b/src/mongo/db/exec/sbe/vm/arith.cpp
@@ -423,6 +423,12 @@ void addNonDecimal(TypeTags tag, Value val, DoubleDoubleSummation& nonDecimalTot
MONGO_UNREACHABLE_TASSERT(5755316);
}
}
+
+void setStdDevArray(value::Value count, value::Value mean, value::Value m2, Array* arr) {
+ arr->setAt(AggStdDevValueElems::kCount, value::TypeTags::NumberInt64, count);
+ arr->setAt(AggStdDevValueElems::kRunningMean, value::TypeTags::NumberDouble, mean);
+ arr->setAt(AggStdDevValueElems::kRunningM2, value::TypeTags::NumberDouble, m2);
+}
} // namespace
void ByteCode::aggDoubleDoubleSumImpl(value::Array* arr,
@@ -496,6 +502,82 @@ void ByteCode::aggDoubleDoubleSumImpl(value::Array* arr,
}
}
+void ByteCode::aggStdDevImpl(value::Array* arr, value::TypeTags rhsTag, value::Value rhsValue) {
+ if (!isNumber(rhsTag)) {
+ return;
+ }
+
+ auto [countTag, countVal] = arr->getAt(AggStdDevValueElems::kCount);
+ tassert(5755201, "The count must be of type NumberInt64", countTag == TypeTags::NumberInt64);
+
+ auto [meanTag, meanVal] = arr->getAt(AggStdDevValueElems::kRunningMean);
+ auto [m2Tag, m2Val] = arr->getAt(AggStdDevValueElems::kRunningM2);
+ tassert(5755202,
+ "The mean and m2 must be of type Double",
+ m2Tag == meanTag && meanTag == TypeTags::NumberDouble);
+
+ double inputDouble = 0.0;
+ // Within our query execution engine, $stdDevPop and $stdDevSamp do not maintain the precision
+ // of decimal types and converts all values to double. We do this here by converting
+ // NumberDecimal to Decimal128 and then extract a double value from it.
+ if (rhsTag == value::TypeTags::NumberDecimal) {
+ auto decimal = value::bitcastTo<Decimal128>(rhsValue);
+ inputDouble = decimal.toDouble();
+ } else {
+ inputDouble = numericCast<double>(rhsTag, rhsValue);
+ }
+ auto curVal = value::bitcastFrom<double>(inputDouble);
+
+ auto count = value::bitcastTo<int64_t>(countVal);
+ tassert(5755211,
+ "The total number of elements must be less than INT64_MAX",
+ ++count < std::numeric_limits<int64_t>::max());
+ auto newCountVal = value::bitcastFrom<int64_t>(count);
+
+ auto [deltaOwned, deltaTag, deltaVal] =
+ genericSub(value::TypeTags::NumberDouble, curVal, value::TypeTags::NumberDouble, meanVal);
+ auto [deltaDivCountOwned, deltaDivCountTag, deltaDivCountVal] =
+ genericDiv(deltaTag, deltaVal, value::TypeTags::NumberInt64, newCountVal);
+ auto [newMeanOwned, newMeanTag, newMeanVal] =
+ genericAdd(meanTag, meanVal, deltaDivCountTag, deltaDivCountVal);
+ auto [newDeltaOwned, newDeltaTag, newDeltaVal] =
+ genericSub(value::TypeTags::NumberDouble, curVal, newMeanTag, newMeanVal);
+ auto [deltaMultNewDeltaOwned, deltaMultNewDeltaTag, deltaMultNewDeltaVal] =
+ genericMul(deltaTag, deltaVal, newDeltaTag, newDeltaVal);
+ auto [newM2Owned, newM2Tag, newM2Val] =
+ genericAdd(m2Tag, m2Val, deltaMultNewDeltaTag, deltaMultNewDeltaVal);
+
+ return setStdDevArray(newCountVal, newMeanVal, newM2Val, arr);
+}
+
+std::tuple<bool, value::TypeTags, value::Value> ByteCode::aggStdDevFinalizeImpl(
+ value::Value fieldValue, bool isSamp) {
+ auto arr = value::getArrayView(fieldValue);
+
+ auto [countTag, countVal] = arr->getAt(AggStdDevValueElems::kCount);
+ tassert(5755207, "The count must be a NumberInt64", countTag == value::TypeTags::NumberInt64);
+
+ auto count = value::bitcastTo<int64_t>(countVal);
+
+ if (count == 0) {
+ return {true, value::TypeTags::Null, 0};
+ }
+
+ if (isSamp && count == 1) {
+ return {true, value::TypeTags::Null, 0};
+ }
+
+ auto [m2Tag, m2] = arr->getAt(AggStdDevValueElems::kRunningM2);
+ tassert(5755208,
+ "The m2 value must be of type NumberDouble",
+ m2Tag == value::TypeTags::NumberDouble);
+ auto m2Double = value::bitcastTo<double>(m2);
+ auto variance = isSamp ? (m2Double / (count - 1)) : (m2Double / count);
+ auto stdDev = sqrt(variance);
+
+ return {true, value::TypeTags::NumberDouble, value::bitcastFrom<double>(stdDev)};
+}
+
std::tuple<bool, value::TypeTags, value::Value> ByteCode::genericSub(value::TypeTags lhsTag,
value::Value lhsValue,
value::TypeTags rhsTag,
diff --git a/src/mongo/db/exec/sbe/vm/vm.cpp b/src/mongo/db/exec/sbe/vm/vm.cpp
index 60c0c802f66..a848cc7dd09 100644
--- a/src/mongo/db/exec/sbe/vm/vm.cpp
+++ b/src/mongo/db/exec/sbe/vm/vm.cpp
@@ -983,7 +983,7 @@ std::tuple<bool, value::TypeTags, value::Value> ByteCode::builtinDoubleDoubleSum
auto [sumTag, sum] = arr->getAt(AggSumValueElems::kNonDecimalTotalSum);
auto [addendTag, addend] = arr->getAt(AggSumValueElems::kNonDecimalTotalAddend);
tassert(5755323,
- "The sum and addend must be NumbetDouble",
+ "The sum and addend must be NumberDouble",
sumTag == addendTag && sumTag == value::TypeTags::NumberDouble);
// We're guaranteed to always have a valid nonDecimalTotal value.
@@ -1035,6 +1035,49 @@ std::tuple<bool, value::TypeTags, value::Value> ByteCode::builtinDoubleDoubleSum
}
}
+std::tuple<bool, value::TypeTags, value::Value> ByteCode::builtinAggStdDev(ArityType arity) {
+ auto [_, fieldTag, fieldValue] = getFromStack(1);
+ // Move the incoming accumulator state from the stack. Given that we are now the owner of the
+ // state we are free to do any in-place update as we see fit.
+ auto [accTag, accValue] = moveOwnedFromStack(0);
+ value::ValueGuard guard{accTag, accValue};
+
+ // Initialize the accumulator.
+ if (accTag == value::TypeTags::Nothing) {
+ auto [newAccTag, newAccValue] = value::makeNewArray();
+ value::ValueGuard newGuard{newAccTag, newAccValue};
+ auto arr = value::getArrayView(newAccValue);
+ arr->reserve(AggStdDevValueElems::kSizeOfArray);
+
+ // The order of the following three elements should match to 'AggStdDevValueElems'.
+ arr->push_back(value::TypeTags::NumberInt64, value::bitcastFrom<int64_t>(0));
+ arr->push_back(value::TypeTags::NumberDouble, value::bitcastFrom<double>(0.0));
+ arr->push_back(value::TypeTags::NumberDouble, value::bitcastFrom<double>(0.0));
+ aggStdDevImpl(arr, fieldTag, fieldValue);
+ newGuard.reset();
+ return {true, newAccTag, newAccValue};
+ }
+ tassert(5755210, "The result slot must be Array-typed", accTag == value::TypeTags::Array);
+
+ aggStdDevImpl(value::getArrayView(accValue), fieldTag, fieldValue);
+ guard.reset();
+ return {true, accTag, accValue};
+}
+
+std::tuple<bool, value::TypeTags, value::Value> ByteCode::builtinStdDevPopFinalize(
+ ArityType arity) {
+ auto [_, fieldTag, fieldValue] = getFromStack(0);
+
+ return aggStdDevFinalizeImpl(fieldValue, false /* isSamp */);
+}
+
+std::tuple<bool, value::TypeTags, value::Value> ByteCode::builtinStdDevSampFinalize(
+ ArityType arity) {
+ auto [_, fieldTag, fieldValue] = getFromStack(0);
+
+ return aggStdDevFinalizeImpl(fieldValue, true /* isSamp */);
+}
+
std::tuple<bool, value::TypeTags, value::Value> ByteCode::aggMin(value::TypeTags accTag,
value::Value accValue,
value::TypeTags fieldTag,
@@ -3779,6 +3822,12 @@ std::tuple<bool, value::TypeTags, value::Value> ByteCode::dispatchBuiltin(Builti
return builtinAggDoubleDoubleSum(arity);
case Builtin::doubleDoubleSumFinalize:
return builtinDoubleDoubleSumFinalize(arity);
+ case Builtin::aggStdDev:
+ return builtinAggStdDev(arity);
+ case Builtin::stdDevPopFinalize:
+ return builtinStdDevPopFinalize(arity);
+ case Builtin::stdDevSampFinalize:
+ return builtinStdDevSampFinalize(arity);
case Builtin::bitTestZero:
return builtinBitTestZero(arity);
case Builtin::bitTestMask:
diff --git a/src/mongo/db/exec/sbe/vm/vm.h b/src/mongo/db/exec/sbe/vm/vm.h
index da81a4fa533..e8e378ffec9 100644
--- a/src/mongo/db/exec/sbe/vm/vm.h
+++ b/src/mongo/db/exec/sbe/vm/vm.h
@@ -362,6 +362,9 @@ enum class Builtin : uint8_t {
doubleDoubleSum, // special double summation
aggDoubleDoubleSum,
doubleDoubleSumFinalize,
+ aggStdDev,
+ stdDevPopFinalize,
+ stdDevSampFinalize,
bitTestZero, // test bitwise mask & value is zero
bitTestMask, // test bitwise mask & value is mask
bitTestPosition, // test BinData with a bit position list
@@ -442,6 +445,28 @@ enum AggSumValueElems {
kMaxSizeOfArray
};
+/**
+ * This enum defines indices into an 'Array' that accumulates $stdDevPop and $stdDevSamp results.
+ *
+ * The array contains 3 elements:
+ * - The element at index `kCount` keeps track of the total number of values processd
+ * - The elements at index `kRunningMean` keeps track of the mean of all the values that have been
+ * processed.
+ * - The elements at index `kRunningM2` keeps track of running M2 value (defined within:
+ * https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm)
+ * for all the values that have been processed.
+ *
+ * See 'aggStdDevImpl()'/'aggStdDev()'/'stdDevPopFinalize() / stdDevSampFinalize()' for more
+ * details.
+ */
+enum AggStdDevValueElems {
+ kCount,
+ kRunningMean,
+ kRunningM2,
+ // This is actually not an index but represents the number of elements stored
+ kSizeOfArray
+};
+
using SmallArityType = uint8_t;
using ArityType = uint32_t;
@@ -806,6 +831,13 @@ private:
void aggDoubleDoubleSumImpl(value::Array* arr, value::TypeTags rhsTag, value::Value rhsValue);
+ // This is an implementation of the following algorithm:
+ // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
+ void aggStdDevImpl(value::Array* arr, value::TypeTags rhsTag, value::Value rhsValue);
+
+ std::tuple<bool, value::TypeTags, value::Value> aggStdDevFinalizeImpl(value::Value fieldValue,
+ bool isSamp);
+
std::tuple<bool, value::TypeTags, value::Value> aggMin(value::TypeTags accTag,
value::Value accValue,
value::TypeTags fieldTag,
@@ -925,6 +957,9 @@ private:
std::tuple<bool, value::TypeTags, value::Value> builtinDoubleDoubleSum(ArityType arity);
std::tuple<bool, value::TypeTags, value::Value> builtinAggDoubleDoubleSum(ArityType arity);
std::tuple<bool, value::TypeTags, value::Value> builtinDoubleDoubleSumFinalize(ArityType arity);
+ std::tuple<bool, value::TypeTags, value::Value> builtinAggStdDev(ArityType arity);
+ std::tuple<bool, value::TypeTags, value::Value> builtinStdDevPopFinalize(ArityType arity);
+ std::tuple<bool, value::TypeTags, value::Value> builtinStdDevSampFinalize(ArityType arity);
std::tuple<bool, value::TypeTags, value::Value> builtinBitTestZero(ArityType arity);
std::tuple<bool, value::TypeTags, value::Value> builtinBitTestMask(ArityType arity);
std::tuple<bool, value::TypeTags, value::Value> builtinBitTestPosition(ArityType arity);