diff options
Diffstat (limited to 'src/mongo/db/exec')
-rw-r--r-- | src/mongo/db/exec/sbe/expressions/expression.cpp | 5 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/vm/arith.cpp | 82 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/vm/vm.cpp | 51 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/vm/vm.h | 35 |
4 files changed, 172 insertions, 1 deletions
diff --git a/src/mongo/db/exec/sbe/expressions/expression.cpp b/src/mongo/db/exec/sbe/expressions/expression.cpp index 8e1a8785ec0..e1dd5d7b205 100644 --- a/src/mongo/db/exec/sbe/expressions/expression.cpp +++ b/src/mongo/db/exec/sbe/expressions/expression.cpp @@ -412,6 +412,11 @@ static stdx::unordered_map<std::string, BuiltinFn> kBuiltinFunctions = { BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::aggDoubleDoubleSum, true}}, {"doubleDoubleSumFinalize", BuiltinFn{[](size_t n) { return n > 0; }, vm::Builtin::doubleDoubleSumFinalize, false}}, + {"aggStdDev", BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::aggStdDev, true}}, + {"stdDevPopFinalize", + BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::stdDevPopFinalize, false}}, + {"stdDevSampFinalize", + BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::stdDevSampFinalize, false}}, {"bitTestZero", BuiltinFn{[](size_t n) { return n == 2; }, vm::Builtin::bitTestZero, false}}, {"bitTestMask", BuiltinFn{[](size_t n) { return n == 2; }, vm::Builtin::bitTestMask, false}}, {"bitTestPosition", diff --git a/src/mongo/db/exec/sbe/vm/arith.cpp b/src/mongo/db/exec/sbe/vm/arith.cpp index ed9d0d00e76..6da30783aab 100644 --- a/src/mongo/db/exec/sbe/vm/arith.cpp +++ b/src/mongo/db/exec/sbe/vm/arith.cpp @@ -423,6 +423,12 @@ void addNonDecimal(TypeTags tag, Value val, DoubleDoubleSummation& nonDecimalTot MONGO_UNREACHABLE_TASSERT(5755316); } } + +void setStdDevArray(value::Value count, value::Value mean, value::Value m2, Array* arr) { + arr->setAt(AggStdDevValueElems::kCount, value::TypeTags::NumberInt64, count); + arr->setAt(AggStdDevValueElems::kRunningMean, value::TypeTags::NumberDouble, mean); + arr->setAt(AggStdDevValueElems::kRunningM2, value::TypeTags::NumberDouble, m2); +} } // namespace void ByteCode::aggDoubleDoubleSumImpl(value::Array* arr, @@ -496,6 +502,82 @@ void ByteCode::aggDoubleDoubleSumImpl(value::Array* arr, } } +void ByteCode::aggStdDevImpl(value::Array* arr, value::TypeTags rhsTag, 
value::Value rhsValue) { + if (!isNumber(rhsTag)) { + return; + } + + auto [countTag, countVal] = arr->getAt(AggStdDevValueElems::kCount); + tassert(5755201, "The count must be of type NumberInt64", countTag == TypeTags::NumberInt64); + + auto [meanTag, meanVal] = arr->getAt(AggStdDevValueElems::kRunningMean); + auto [m2Tag, m2Val] = arr->getAt(AggStdDevValueElems::kRunningM2); + tassert(5755202, + "The mean and m2 must be of type Double", + m2Tag == meanTag && meanTag == TypeTags::NumberDouble); + + double inputDouble = 0.0; + // Within our query execution engine, $stdDevPop and $stdDevSamp do not maintain the precision + // of decimal types and convert all values to double. We do this here by converting + // NumberDecimal to Decimal128 and then extracting a double value from it. + if (rhsTag == value::TypeTags::NumberDecimal) { + auto decimal = value::bitcastTo<Decimal128>(rhsValue); + inputDouble = decimal.toDouble(); + } else { + inputDouble = numericCast<double>(rhsTag, rhsValue); + } + auto curVal = value::bitcastFrom<double>(inputDouble); + + auto count = value::bitcastTo<int64_t>(countVal); + tassert(5755211, + "The total number of elements must be less than INT64_MAX", + ++count < std::numeric_limits<int64_t>::max()); + auto newCountVal = value::bitcastFrom<int64_t>(count); + + auto [deltaOwned, deltaTag, deltaVal] = + genericSub(value::TypeTags::NumberDouble, curVal, value::TypeTags::NumberDouble, meanVal); + auto [deltaDivCountOwned, deltaDivCountTag, deltaDivCountVal] = + genericDiv(deltaTag, deltaVal, value::TypeTags::NumberInt64, newCountVal); + auto [newMeanOwned, newMeanTag, newMeanVal] = + genericAdd(meanTag, meanVal, deltaDivCountTag, deltaDivCountVal); + auto [newDeltaOwned, newDeltaTag, newDeltaVal] = + genericSub(value::TypeTags::NumberDouble, curVal, newMeanTag, newMeanVal); + auto [deltaMultNewDeltaOwned, deltaMultNewDeltaTag, deltaMultNewDeltaVal] = + genericMul(deltaTag, deltaVal, newDeltaTag, newDeltaVal); + auto [newM2Owned, newM2Tag, 
newM2Val] = + genericAdd(m2Tag, m2Val, deltaMultNewDeltaTag, deltaMultNewDeltaVal); + + return setStdDevArray(newCountVal, newMeanVal, newM2Val, arr); +} + +std::tuple<bool, value::TypeTags, value::Value> ByteCode::aggStdDevFinalizeImpl( + value::Value fieldValue, bool isSamp) { + auto arr = value::getArrayView(fieldValue); + + auto [countTag, countVal] = arr->getAt(AggStdDevValueElems::kCount); + tassert(5755207, "The count must be a NumberInt64", countTag == value::TypeTags::NumberInt64); + + auto count = value::bitcastTo<int64_t>(countVal); + + if (count == 0) { + return {true, value::TypeTags::Null, 0}; + } + + if (isSamp && count == 1) { + return {true, value::TypeTags::Null, 0}; + } + + auto [m2Tag, m2] = arr->getAt(AggStdDevValueElems::kRunningM2); + tassert(5755208, + "The m2 value must be of type NumberDouble", + m2Tag == value::TypeTags::NumberDouble); + auto m2Double = value::bitcastTo<double>(m2); + auto variance = isSamp ? (m2Double / (count - 1)) : (m2Double / count); + auto stdDev = sqrt(variance); + + return {true, value::TypeTags::NumberDouble, value::bitcastFrom<double>(stdDev)}; +} + std::tuple<bool, value::TypeTags, value::Value> ByteCode::genericSub(value::TypeTags lhsTag, value::Value lhsValue, value::TypeTags rhsTag, diff --git a/src/mongo/db/exec/sbe/vm/vm.cpp b/src/mongo/db/exec/sbe/vm/vm.cpp index 60c0c802f66..a848cc7dd09 100644 --- a/src/mongo/db/exec/sbe/vm/vm.cpp +++ b/src/mongo/db/exec/sbe/vm/vm.cpp @@ -983,7 +983,7 @@ std::tuple<bool, value::TypeTags, value::Value> ByteCode::builtinDoubleDoubleSum auto [sumTag, sum] = arr->getAt(AggSumValueElems::kNonDecimalTotalSum); auto [addendTag, addend] = arr->getAt(AggSumValueElems::kNonDecimalTotalAddend); tassert(5755323, - "The sum and addend must be NumbetDouble", + "The sum and addend must be NumberDouble", sumTag == addendTag && sumTag == value::TypeTags::NumberDouble); // We're guaranteed to always have a valid nonDecimalTotal value. 
@@ -1035,6 +1035,49 @@ std::tuple<bool, value::TypeTags, value::Value> ByteCode::builtinDoubleDoubleSum } } +std::tuple<bool, value::TypeTags, value::Value> ByteCode::builtinAggStdDev(ArityType arity) { + auto [_, fieldTag, fieldValue] = getFromStack(1); + // Move the incoming accumulator state from the stack. Given that we are now the owner of the + // state we are free to do any in-place update as we see fit. + auto [accTag, accValue] = moveOwnedFromStack(0); + value::ValueGuard guard{accTag, accValue}; + + // Initialize the accumulator. + if (accTag == value::TypeTags::Nothing) { + auto [newAccTag, newAccValue] = value::makeNewArray(); + value::ValueGuard newGuard{newAccTag, newAccValue}; + auto arr = value::getArrayView(newAccValue); + arr->reserve(AggStdDevValueElems::kSizeOfArray); + + // The order of the following three elements should match to 'AggStdDevValueElems'. + arr->push_back(value::TypeTags::NumberInt64, value::bitcastFrom<int64_t>(0)); + arr->push_back(value::TypeTags::NumberDouble, value::bitcastFrom<double>(0.0)); + arr->push_back(value::TypeTags::NumberDouble, value::bitcastFrom<double>(0.0)); + aggStdDevImpl(arr, fieldTag, fieldValue); + newGuard.reset(); + return {true, newAccTag, newAccValue}; + } + tassert(5755210, "The result slot must be Array-typed", accTag == value::TypeTags::Array); + + aggStdDevImpl(value::getArrayView(accValue), fieldTag, fieldValue); + guard.reset(); + return {true, accTag, accValue}; +} + +std::tuple<bool, value::TypeTags, value::Value> ByteCode::builtinStdDevPopFinalize( + ArityType arity) { + auto [_, fieldTag, fieldValue] = getFromStack(0); + + return aggStdDevFinalizeImpl(fieldValue, false /* isSamp */); +} + +std::tuple<bool, value::TypeTags, value::Value> ByteCode::builtinStdDevSampFinalize( + ArityType arity) { + auto [_, fieldTag, fieldValue] = getFromStack(0); + + return aggStdDevFinalizeImpl(fieldValue, true /* isSamp */); +} + std::tuple<bool, value::TypeTags, value::Value> 
ByteCode::aggMin(value::TypeTags accTag, value::Value accValue, value::TypeTags fieldTag, @@ -3779,6 +3822,12 @@ std::tuple<bool, value::TypeTags, value::Value> ByteCode::dispatchBuiltin(Builti return builtinAggDoubleDoubleSum(arity); case Builtin::doubleDoubleSumFinalize: return builtinDoubleDoubleSumFinalize(arity); + case Builtin::aggStdDev: + return builtinAggStdDev(arity); + case Builtin::stdDevPopFinalize: + return builtinStdDevPopFinalize(arity); + case Builtin::stdDevSampFinalize: + return builtinStdDevSampFinalize(arity); case Builtin::bitTestZero: return builtinBitTestZero(arity); case Builtin::bitTestMask: diff --git a/src/mongo/db/exec/sbe/vm/vm.h b/src/mongo/db/exec/sbe/vm/vm.h index da81a4fa533..e8e378ffec9 100644 --- a/src/mongo/db/exec/sbe/vm/vm.h +++ b/src/mongo/db/exec/sbe/vm/vm.h @@ -362,6 +362,9 @@ enum class Builtin : uint8_t { doubleDoubleSum, // special double summation aggDoubleDoubleSum, doubleDoubleSumFinalize, + aggStdDev, + stdDevPopFinalize, + stdDevSampFinalize, bitTestZero, // test bitwise mask & value is zero bitTestMask, // test bitwise mask & value is mask bitTestPosition, // test BinData with a bit position list @@ -442,6 +445,28 @@ enum AggSumValueElems { kMaxSizeOfArray }; +/** + * This enum defines indices into an 'Array' that accumulates $stdDevPop and $stdDevSamp results. + * + * The array contains 3 elements: + * - The element at index `kCount` keeps track of the total number of values processed + * - The element at index `kRunningMean` keeps track of the mean of all the values that have been + * processed. + * - The element at index `kRunningM2` keeps track of the running M2 value (defined within: + * https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm) + * for all the values that have been processed. + * + * See 'aggStdDevImpl()'/'aggStdDev()'/'stdDevPopFinalize() / stdDevSampFinalize()' for more + * details. 
+ */ +enum AggStdDevValueElems { + kCount, + kRunningMean, + kRunningM2, + // This is actually not an index but represents the number of elements stored + kSizeOfArray +}; + using SmallArityType = uint8_t; using ArityType = uint32_t; @@ -806,6 +831,13 @@ private: void aggDoubleDoubleSumImpl(value::Array* arr, value::TypeTags rhsTag, value::Value rhsValue); + // This is an implementation of the following algorithm: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm + void aggStdDevImpl(value::Array* arr, value::TypeTags rhsTag, value::Value rhsValue); + + std::tuple<bool, value::TypeTags, value::Value> aggStdDevFinalizeImpl(value::Value fieldValue, + bool isSamp); + std::tuple<bool, value::TypeTags, value::Value> aggMin(value::TypeTags accTag, value::Value accValue, value::TypeTags fieldTag, @@ -925,6 +957,9 @@ private: std::tuple<bool, value::TypeTags, value::Value> builtinDoubleDoubleSum(ArityType arity); std::tuple<bool, value::TypeTags, value::Value> builtinAggDoubleDoubleSum(ArityType arity); std::tuple<bool, value::TypeTags, value::Value> builtinDoubleDoubleSumFinalize(ArityType arity); + std::tuple<bool, value::TypeTags, value::Value> builtinAggStdDev(ArityType arity); + std::tuple<bool, value::TypeTags, value::Value> builtinStdDevPopFinalize(ArityType arity); + std::tuple<bool, value::TypeTags, value::Value> builtinStdDevSampFinalize(ArityType arity); std::tuple<bool, value::TypeTags, value::Value> builtinBitTestZero(ArityType arity); std::tuple<bool, value::TypeTags, value::Value> builtinBitTestMask(ArityType arity); std::tuple<bool, value::TypeTags, value::Value> builtinBitTestPosition(ArityType arity); |