diff options
author | Ruoxin Xu <ruoxin.xu@mongodb.com> | 2021-03-24 13:42:10 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2021-04-19 12:54:08 +0000 |
commit | c2c69206f2cc5460c7688c6ae331e772ac69fe5c (patch) | |
tree | b9c80db96e796f5fa596cff42aeaaec163fc0bb4 /src/mongo/db/pipeline/window_function | |
parent | 27b3c5b44644f5ba683c9ffe56f588f6382711a0 (diff) | |
download | mongo-c2c69206f2cc5460c7688c6ae331e772ac69fe5c.tar.gz |
SERVER-54241 Implement removable $covariance function
Diffstat (limited to 'src/mongo/db/pipeline/window_function')
6 files changed, 650 insertions, 1 deletions
diff --git a/src/mongo/db/pipeline/window_function/window_function.h b/src/mongo/db/pipeline/window_function/window_function.h index 7f22e20c1f3..6a496b96675 100644 --- a/src/mongo/db/pipeline/window_function/window_function.h +++ b/src/mongo/db/pipeline/window_function/window_function.h @@ -29,7 +29,6 @@ #pragma once -#include "mongo/db/pipeline/accumulator.h" #include "mongo/db/pipeline/document_source.h" #include "mongo/db/pipeline/expression.h" diff --git a/src/mongo/db/pipeline/window_function/window_function_covariance.cpp b/src/mongo/db/pipeline/window_function/window_function_covariance.cpp new file mode 100644 index 00000000000..cab8fdb6193 --- /dev/null +++ b/src/mongo/db/pipeline/window_function/window_function_covariance.cpp @@ -0,0 +1,164 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/pipeline/window_function/window_function_covariance.h" +#include "mongo/db/pipeline/window_function/window_function_sum.h" + +#include "mongo/db/pipeline/document_source.h" +#include "mongo/db/pipeline/expression.h" + +namespace mongo { + +namespace { +// The input Value must be a vector of exactly two numeric Value. +bool validateValue(const Value& val) { + return (val.isArray() && val.getArray().size() == 2 && val.getArray()[0].numeric() && + val.getArray()[1].numeric()); +} + +// Convert the non-finite input value to a single Value that is then added to the underlying +// 'WindowFunctionSum'. +Value convertNonFiniteInputValue(Value value) { + int posCnt = 0, negCnt = 0, nanCnt = 0; + bool isDecimal = false; + for (auto val : value.getArray()) { + if (val.isNaN()) { + nanCnt++; + } else if (val.getType() == NumberDecimal) { + if (val.isInfinite()) + val.coerceToDecimal().isNegative() ? negCnt++ : posCnt++; + isDecimal = true; + } else if (val.numeric()) { + auto doubleVal = val.coerceToDouble(); + if (doubleVal == std::numeric_limits<double>::infinity()) + posCnt++; + else if (doubleVal == -std::numeric_limits<double>::infinity()) + negCnt++; + } + } + + // Should return NaN over Inf value if both NaN and Inf exist. + // Returns NaN if any NaN in 'value' or the two values are of different sign. + if (nanCnt > 0 || posCnt * negCnt > 0) + return isDecimal ? Value(Decimal128::kPositiveNaN) + : Value(std::numeric_limits<double>::quiet_NaN()); + + if (isDecimal) + return posCnt > 0 ? Value(Decimal128::kPositiveInfinity) + : Value(Decimal128::kNegativeInfinity); + else + return posCnt > 0 ? Value(std::numeric_limits<double>::infinity()) + : Value(-std::numeric_limits<double>::infinity()); +} +} // namespace + +WindowFunctionCovariance::WindowFunctionCovariance(ExpressionContext* const expCtx, bool isSamp) + : WindowFunctionState(expCtx), _isSamp(isSamp), _meanX(expCtx), _meanY(expCtx), _cXY(expCtx) { + _memUsageBytes = sizeof(*this); +} + +Value WindowFunctionCovariance::getValue() const { + if (_count == 1 && !_isSamp) + return Value(0.0); + + const double adjustedCount = (_isSamp ? _count - 1 : _count); + if (adjustedCount <= 0) + return kDefault; // Covariance not well defined in this case. + + auto output = _cXY.getValue(); + if (output.getType() == NumberDecimal) { + output = uassertStatusOK(ExpressionDivide::apply(output, Value(adjustedCount))); + } else if (output.numeric()) { + output = Value(output.coerceToDouble() / adjustedCount); + } + + return output; +} + +void WindowFunctionCovariance::add(Value value) { + // Not supported type of input have no impact on covariance. + if (!validateValue(value)) + return; + + const auto& arr = value.getArray(); + // The non-finite (nan/inf) value is handled by 'WindowFunctionSum' directly and is not taken + // into account when calculating the intermediate values and covariance. + if (arr[0].isNaN() || arr[1].isNaN() || arr[0].isInfinite() || arr[1].isInfinite()) { + auto infValue = convertNonFiniteInputValue(value); + _cXY.add(infValue); + return; + } + + _count++; + // Update covariance and means. + // This is an implementation of the following algorithm: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online + auto deltaX = uassertStatusOK(ExpressionSubtract::apply(arr[0], _meanX.getValue())); + _meanX.add(arr[0]); + _meanY.add(arr[1]); + auto deltaY = uassertStatusOK(ExpressionSubtract::apply(arr[1], _meanY.getValue())); + auto deltaCXY = uassertStatusOK(ExpressionMultiply::apply(deltaX, deltaY)); + _cXY.add(deltaCXY); +} + +void WindowFunctionCovariance::remove(Value value) { + // Not supported type of input have no impact on covariance. + if (!validateValue(value)) + return; + + const auto& arr = value.getArray(); + if (arr[0].isNaN() || arr[1].isNaN() || arr[0].isInfinite() || arr[1].isInfinite()) { + auto infValue = convertNonFiniteInputValue(value); + _cXY.remove(infValue); + return; + } + + tassert(5424100, "Can't remove from an empty WindowFunctionCovariance", _count > 0); + _count--; + if (_count == 0) { + reset(); + return; + } + + _meanX.remove(arr[0]); + auto deltaX = uassertStatusOK(ExpressionSubtract::apply(arr[0], _meanX.getValue())); + auto deltaY = uassertStatusOK(ExpressionSubtract::apply(arr[1], _meanY.getValue())); + auto deltaCXY = uassertStatusOK(ExpressionMultiply::apply(deltaX, deltaY)); + _cXY.remove(deltaCXY); + _meanY.remove(arr[1]); +} + +void WindowFunctionCovariance::reset() { + _count = 0; + _meanX.reset(); + _meanY.reset(); + _cXY.reset(); +} + +} // namespace mongo diff --git a/src/mongo/db/pipeline/window_function/window_function_covariance.h b/src/mongo/db/pipeline/window_function/window_function_covariance.h new file mode 100644 index 00000000000..a736be00c9c --- /dev/null +++ b/src/mongo/db/pipeline/window_function/window_function_covariance.h @@ -0,0 +1,86 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/pipeline/window_function/window_function.h" +#include "mongo/db/pipeline/window_function/window_function_avg.h" +#include "mongo/db/pipeline/window_function/window_function_sum.h" +#include "mongo/platform/decimal128.h" + +namespace mongo { + +class WindowFunctionCovariance : public WindowFunctionState { +public: + static inline const Value kDefault = Value(BSONNULL); + + WindowFunctionCovariance(ExpressionContext* const expCtx, bool isSamp); + + void add(Value value) override; + + void remove(Value value) override; + + void reset() override; + + Value getValue() const override; + + bool isSample() const { + return _isSamp; + } + +private: + bool _isSamp; + long long _count = 0; + + WindowFunctionAvg _meanX; + WindowFunctionAvg _meanY; + WindowFunctionSum _cXY; +}; + +class WindowFunctionCovarianceSamp final : public WindowFunctionCovariance { +public: + static std::unique_ptr<WindowFunctionState> create(ExpressionContext* const expCtx) { + return std::make_unique<WindowFunctionCovarianceSamp>(expCtx); + } + + explicit WindowFunctionCovarianceSamp(ExpressionContext* const expCtx) + : WindowFunctionCovariance(expCtx, true) {} +}; + +class WindowFunctionCovariancePop final : public WindowFunctionCovariance { +public: + static std::unique_ptr<WindowFunctionState> create(ExpressionContext* const expCtx) { + return std::make_unique<WindowFunctionCovariancePop>(expCtx); + } + + explicit WindowFunctionCovariancePop(ExpressionContext* const expCtx) + : WindowFunctionCovariance(expCtx, false) {} +}; + +} // namespace mongo diff --git a/src/mongo/db/pipeline/window_function/window_function_covariance_test.cpp b/src/mongo/db/pipeline/window_function/window_function_covariance_test.cpp new file mode 100644 index 00000000000..9d1b296a1ef --- /dev/null +++ b/src/mongo/db/pipeline/window_function/window_function_covariance_test.cpp @@ -0,0 +1,398 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/db/exec/document_value/document_value_test_util.h" +#include "mongo/db/pipeline/window_function/window_function.h" +#include "mongo/db/pipeline/window_function/window_function_covariance.h" +#include "mongo/unittest/unittest.h" + +namespace mongo { +namespace { + +class WindowFunctionCovarianceSampTest : public unittest::Test { +public: + WindowFunctionCovarianceSampTest() : covariance(nullptr) {} + + WindowFunctionCovarianceSamp covariance; +}; + +class WindowFunctionCovariancePopTest : public unittest::Test { +public: + WindowFunctionCovariancePopTest() : covariance(nullptr) {} + + WindowFunctionCovariancePop covariance; +}; + +void addToWindowCovariance(WindowFunctionCovariance* covariance, + const std::vector<Value>& valToAdd) { + for (auto val : valToAdd) { + covariance->add(val); + } +} + +// -------------- Test CovarianceSamp window function ---------- +TEST_F(WindowFunctionCovarianceSampTest, EmptyWindowShouldReturnNull) { + ASSERT_VALUE_EQ(covariance.getValue(), Value(BSONNULL)); +} + +TEST_F(WindowFunctionCovarianceSampTest, SingletonWindowShouldReturnNull) { + covariance.add(Value(std::vector<Value>({Value(1.0), Value(2.0)}))); + ASSERT_VALUE_EQ(covariance.getValue(), Value(BSONNULL)); +} + +TEST_F(WindowFunctionCovarianceSampTest, WindowAddition) { + const std::vector<Value> valToAdd = { + Value(std::vector<Value>({Value(0), Value(1.5)})), + Value(std::vector<Value>({Value(1.4), Value(2.5)})), + }; + addToWindowCovariance(&covariance, valToAdd); + + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 0.700000), 1e-5); + + // Test addition to the window correctly accumulate the result. + covariance.add(Value(std::vector<Value>({Value(4.7), Value(3.6)}))); + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 2.483334), 1e-5); +} + +TEST_F(WindowFunctionCovarianceSampTest, WindowRemoval) { + const std::vector<Value> values = { + Value(std::vector<Value>({Value(Decimal128(0)), Value(Decimal128(1.5))})), + Value(std::vector<Value>({Value(1.4), Value(2.5)})), + Value(std::vector<Value>({Value(4.7), Value(3.6)})), + }; + addToWindowCovariance(&covariance, values); + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 2.483334), 1e-5); + + covariance.remove(values[0]); + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 1.815000), 1e-5); + + // Adding back the value just removed should result in the same value as before. + covariance.add(values[0]); + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 2.483334), 1e-5); + covariance.remove(values[0]); + + covariance.remove(values[1]); + ASSERT_VALUE_EQ(covariance.getValue(), Value(BSONNULL)); + covariance.remove(values[2]); + ASSERT_VALUE_EQ(covariance.getValue(), Value(BSONNULL)); +} + +TEST_F(WindowFunctionCovarianceSampTest, CanHandleNaN) { + std::vector<Value> values = { + Value(std::vector<Value>({Value(std::numeric_limits<double>::quiet_NaN()), + Value(std::numeric_limits<double>::quiet_NaN())})), + Value(std::vector<Value>({Value(1.0), Value(2.0)})), + Value(std::vector<Value>({Value(2.0), Value(4.0)})), + }; + addToWindowCovariance(&covariance, values); + // The window contains NaN value, so the result should be NaN. + ASSERT_VALUE_EQ(covariance.getValue(), Value(std::numeric_limits<double>::quiet_NaN())); + + covariance.remove(values[0]); // Remove the NaN value in the window. + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 1.0), 1e-5); + + covariance.reset(); + + values = std::vector<Value>({ + Value( + std::vector<Value>({Value(Decimal128::kPositiveNaN), Value(Decimal128::kPositiveNaN)})), + Value(std::vector<Value>({Value(1.0), Value(2.0)})), + Value(std::vector<Value>({Value(2.0), Value(4.0)})), + }); + addToWindowCovariance(&covariance, values); + ASSERT_VALUE_EQ(covariance.getValue(), Value(Decimal128::kPositiveNaN)); + + covariance.remove(values[0]); // Remove the NaN value in the window. + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 1.0), 1e-5); +} + +TEST_F(WindowFunctionCovarianceSampTest, CanHandleInfinity) { + // Test double infinity. + std::vector<Value> values = { + Value(std::vector<Value>({Value(-std::numeric_limits<double>::infinity()), + Value(std::numeric_limits<double>::infinity())})), + Value(std::vector<Value>({Value(std::numeric_limits<double>::infinity()), + Value(std::numeric_limits<double>::infinity())})), + Value(std::vector<Value>({Value(1.0), Value(2.0)})), + Value(std::vector<Value>({Value(2.0), Value(4.0)})), + }; + addToWindowCovariance(&covariance, values); + // The window contains infinity values of different sign, so the result should be NaN. + ASSERT_VALUE_EQ(covariance.getValue(), Value(std::numeric_limits<double>::quiet_NaN())); + + // Remove the NaN/infinite value in the window. The remaining inf value should be resolved to + // the corresponding inf value. + covariance.remove(values[0]); + ASSERT_VALUE_EQ(covariance.getValue(), Value(std::numeric_limits<double>::infinity())); + + covariance.remove(values[1]); + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 1.0), 1e-5); + + covariance.reset(); + + // Test Decimal128 infinity. + values = { + Value(std::vector<Value>( + {Value(Decimal128::kNegativeInfinity), Value(Decimal128::kPositiveNaN)})), + Value( + std::vector<Value>({Value(Decimal128::kPositiveNaN), Value(Decimal128::kPositiveNaN)})), + Value(std::vector<Value>({Value(1.0), Value(2.0)})), + Value(std::vector<Value>({Value(2.0), Value(4.0)})), + }; + addToWindowCovariance(&covariance, values); + // The window contains infinity values of different sign, so the result should be NaN. + ASSERT_VALUE_EQ(covariance.getValue(), Value(Decimal128::kPositiveNaN)); + + // Remove the NaN/infinite value in the window. The remaining inf value should be resolved to + // the corresponding inf value. + covariance.remove(values[0]); + ASSERT_VALUE_EQ(covariance.getValue(), Value(Decimal128::kPositiveNaN)); + + covariance.remove(values[1]); + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 1.0), 1e-5); +} + +TEST_F(WindowFunctionCovarianceSampTest, ReturnNaNOverInfIfExistsBoth) { + std::vector<Value> values = { + Value(std::vector<Value>({Value(std::numeric_limits<double>::quiet_NaN()), + Value(std::numeric_limits<double>::quiet_NaN())})), + Value(std::vector<Value>({Value(std::numeric_limits<double>::infinity()), + Value(std::numeric_limits<double>::infinity())})), + Value(std::vector<Value>({Value(1.0), Value(2.0)})), + Value(std::vector<Value>({Value(2.0), Value(4.0)})), + }; + addToWindowCovariance(&covariance, values); + // When the window contains both infinity and NaN, the result should be NaN rather than Inf. + ASSERT_VALUE_EQ(covariance.getValue(), Value(std::numeric_limits<double>::quiet_NaN())); + + covariance.remove(values[0]); + ASSERT_VALUE_EQ(covariance.getValue(), Value(std::numeric_limits<double>::infinity())); + covariance.remove(values[1]); + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 1.0), 1e-5); +} + +// -------------- Test CovariancePop window function ---------- +TEST_F(WindowFunctionCovariancePopTest, EmptyWindowShouldReturnNull) { + ASSERT_VALUE_EQ(covariance.getValue(), Value(BSONNULL)); +} + +TEST_F(WindowFunctionCovariancePopTest, SingletonWindowShouldReturnZero) { + covariance.add(Value(std::vector<Value>({Value(1.0), Value(2.0)}))); + ASSERT_VALUE_EQ(covariance.getValue(), Value(0.0)); +} + +TEST_F(WindowFunctionCovariancePopTest, WindowAddition) { + const std::vector<Value> valToAdd = { + Value(std::vector<Value>({Value(0), Value(1.5)})), + Value(std::vector<Value>({Value(1.4), Value(2.5)})), + }; + addToWindowCovariance(&covariance, valToAdd); + + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 0.350000), 1e-5); + + // Test addition to the window correctly accumulate the result. + covariance.add(Value(std::vector<Value>({Value(4.7), Value(3.6)}))); + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 1.655556), 1e-5); +} + +TEST_F(WindowFunctionCovariancePopTest, WindowRemoval) { + const std::vector<Value> values = { + Value(std::vector<Value>({Value(0), Value(1.5)})), + Value(std::vector<Value>({Value(Decimal128(1.4)), Value(Decimal128(2.5))})), + Value(std::vector<Value>({Value(4.7), Value(3.6)})), + }; + addToWindowCovariance(&covariance, values); + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 1.655556), 1e-5); + + covariance.remove(values[0]); + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 0.907500), 1e-5); + + // Adding back the value just removed should result in the same value as before. + covariance.add(values[0]); + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 1.655556), 1e-5); + covariance.remove(values[0]); + + covariance.remove(values[1]); + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 0.0), 1e-5); + covariance.remove(values[2]); + ASSERT_VALUE_EQ(covariance.getValue(), Value(BSONNULL)); +} + +TEST_F(WindowFunctionCovariancePopTest, CanHandleNaN) { + std::vector<Value> values = { + Value(std::vector<Value>({Value(std::numeric_limits<double>::quiet_NaN()), + Value(std::numeric_limits<double>::quiet_NaN())})), + Value(std::vector<Value>({Value(1.0), Value(2.0)})), + Value(std::vector<Value>({Value(2.0), Value(4.0)})), + }; + addToWindowCovariance(&covariance, values); + // The window contains NaN value, so the result should be NaN. + ASSERT_VALUE_EQ(covariance.getValue(), Value(std::numeric_limits<double>::quiet_NaN())); + + covariance.remove(values[0]); // Remove the NaN value in the window. + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 0.5), 1e-5); + + covariance.reset(); + + values = std::vector<Value>({ + Value( + std::vector<Value>({Value(Decimal128::kPositiveNaN), Value(Decimal128::kPositiveNaN)})), + Value(std::vector<Value>({Value(1.0), Value(2.0)})), + Value(std::vector<Value>({Value(2.0), Value(4.0)})), + }); + addToWindowCovariance(&covariance, values); + ASSERT_VALUE_EQ(covariance.getValue(), Value(Decimal128::kPositiveNaN)); + + covariance.remove(values[0]); // Remove the NaN value in the window. + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 0.5), 1e-5); +} + +TEST_F(WindowFunctionCovariancePopTest, CanHandleInfinity) { + // Test double infinity. + std::vector<Value> values = { + Value(std::vector<Value>({Value(-std::numeric_limits<double>::infinity()), + Value(std::numeric_limits<double>::infinity())})), + Value(std::vector<Value>({Value(std::numeric_limits<double>::infinity()), + Value(std::numeric_limits<double>::infinity())})), + Value(std::vector<Value>({Value(1.0), Value(2.0)})), + Value(std::vector<Value>({Value(2.0), Value(4.0)})), + }; + addToWindowCovariance(&covariance, values); + // The window contains infinity values of different sign, so the result should be NaN. + ASSERT_VALUE_EQ(covariance.getValue(), Value(std::numeric_limits<double>::quiet_NaN())); + + // Remove the NaN/infinite value in the window. The remaining inf value should be resolved to + // the corresponding inf value. + covariance.remove(values[0]); + ASSERT_VALUE_EQ(covariance.getValue(), Value(std::numeric_limits<double>::infinity())); + + covariance.remove(values[1]); + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 0.5), 1e-5); + + covariance.reset(); + + // Test Decimal128 infinity. + values = { + Value(std::vector<Value>( + {Value(Decimal128::kNegativeInfinity), Value(Decimal128::kPositiveInfinity)})), + Value(std::vector<Value>( + {Value(Decimal128::kPositiveInfinity), Value(Decimal128::kPositiveInfinity)})), + Value(std::vector<Value>({Value(1.0), Value(2.0)})), + Value(std::vector<Value>({Value(2.0), Value(4.0)})), + }; + addToWindowCovariance(&covariance, values); + // The window contains infinity values of different sign, so the result should be NaN. + ASSERT_VALUE_EQ(covariance.getValue(), Value(Decimal128::kPositiveNaN)); + + // Remove the NaN/infinite value in the window. The remaining inf value should be resolved to + // the corresponding inf value. + covariance.remove(values[0]); + ASSERT_VALUE_EQ(covariance.getValue(), Value(Decimal128::kPositiveInfinity)); + + covariance.remove(values[1]); + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 0.5), 1e-5); +} + +TEST_F(WindowFunctionCovariancePopTest, ReturnNaNOverInfIfExistsBoth) { + std::vector<Value> values = { + Value(std::vector<Value>({Value(std::numeric_limits<double>::quiet_NaN()), + Value(std::numeric_limits<double>::quiet_NaN())})), + Value(std::vector<Value>({Value(std::numeric_limits<double>::infinity()), + Value(std::numeric_limits<double>::infinity())})), + Value(std::vector<Value>({Value(1.0), Value(2.0)})), + Value(std::vector<Value>({Value(2.0), Value(4.0)})), + }; + addToWindowCovariance(&covariance, values); + // When the window contains both infinity and NaN, the result should be NaN rather than Inf. + ASSERT_VALUE_EQ(covariance.getValue(), Value(std::numeric_limits<double>::quiet_NaN())); + + covariance.remove(values[0]); + ASSERT_VALUE_EQ(covariance.getValue(), Value(std::numeric_limits<double>::infinity())); + covariance.remove(values[1]); + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 0.5), 1e-5); +} + +TEST_F(WindowFunctionCovariancePopTest, NonNumericTypesHaveNoImpactOnCovariance) { + const std::string str = "non-numeric-type"; + const std::vector<Value> values = { + // Numeric type check is before NaN check, so this value should not cause NaN result. + Value(std::vector<Value>({Value(std::numeric_limits<double>::quiet_NaN()), Value(str)})), + Value(std::vector<Value>({Value(1.0), Value(2.0)})), + Value(std::vector<Value>({Value(str), Value(2.0)})), + Value(std::vector<Value>({Value(1.0), Value(BSONNULL)})), + Value(std::vector<Value>({Value(2.0), Value(4.0)})), + Value(std::vector<Value>({Value(str), Value(BSONNULL)})), + }; + addToWindowCovariance(&covariance, values); + + // Note that non-numeric input simply has no impact on coavariance and won't throw or fail the + // computation. + + // Only numeric values 'values[1]' and 'values[4]' should be considered "valid". + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 0.5), 1e-5); + + // Removing a non-numeric value, covariance should remain the same. + covariance.remove(values[0]); + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 0.5), 1e-5); + + // Remove a numeric value. + covariance.remove(values[1]); + ASSERT_VALUE_EQ(covariance.getValue(), Value(0.0)); +} + +TEST_F(WindowFunctionCovariancePopTest, NonDecimalNumericResultShouldBeCoercedToDouble) { + ASSERT_VALUE_EQ(covariance.getValue(), Value(BSONNULL)); + covariance.add(Value(std::vector<Value>({Value(0), Value(1)}))); + + ASSERT_EQUALS(covariance.getValue().getType(), NumberDouble); + ASSERT_VALUE_EQ(covariance.getValue(), Value(0.0)); + + covariance.add(Value(std::vector<Value>({Value(1), Value(2)}))); + ASSERT_EQUALS(covariance.getValue().getType(), NumberDouble); +} + +TEST_F(WindowFunctionCovariancePopTest, WidenTypeToDecimalOnlyIfNeeded) { + const std::vector<Value> values = { + Value(std::vector<Value>({Value(0), Value(1.5)})), + Value(std::vector<Value>({Value(1.4), Value(2.5)})), + }; + addToWindowCovariance(&covariance, values); + + ASSERT_EQUALS(covariance.getValue().getType(), NumberDouble); + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 0.350000), 1e-5); + + covariance.add(Value(std::vector<Value>({Value(Decimal128(4.7)), Value(Decimal128(3.6))}))); + ASSERT_EQUALS(covariance.getValue().getType(), NumberDecimal); + ASSERT_LTE(fabs(covariance.getValue().coerceToDouble() - 1.655556), 1e-5); +} + +} // namespace +} // namespace mongo diff --git a/src/mongo/db/pipeline/window_function/window_function_min_max.h b/src/mongo/db/pipeline/window_function/window_function_min_max.h index e450caa6bc1..0ff212484c0 100644 --- a/src/mongo/db/pipeline/window_function/window_function_min_max.h +++ b/src/mongo/db/pipeline/window_function/window_function_min_max.h @@ -29,6 +29,7 @@ #pragma once +#include "mongo/db/pipeline/accumulator.h" #include "mongo/db/pipeline/window_function/window_function.h" namespace mongo { diff --git a/src/mongo/db/pipeline/window_function/window_function_stddev.h b/src/mongo/db/pipeline/window_function/window_function_stddev.h index 5d818c06479..ab4d5c33b2b 100644 --- a/src/mongo/db/pipeline/window_function/window_function_stddev.h +++ b/src/mongo/db/pipeline/window_function/window_function_stddev.h @@ -29,6 +29,7 @@ #pragma once +#include "mongo/db/pipeline/accumulator.h" #include "mongo/db/pipeline/window_function/window_function.h" namespace mongo { |