diff options
author | Hartek Sabharwal <hartek.sabharwal@mongodb.com> | 2021-03-05 17:23:04 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2021-03-05 21:19:16 +0000 |
commit | 8fc03988ff0cdf087b9108aff3a649e59f45cddb (patch) | |
tree | b4893ae8f2533ac5596996b36f05f40643962a17 /src/mongo/db/pipeline | |
parent | 69027dee0ae53ef10a0adcc324d06ef12d0f634f (diff) | |
download | mongo-8fc03988ff0cdf087b9108aff3a649e59f45cddb.tar.gz |
SERVER-54234 Implement removable $stddev window function
Diffstat (limited to 'src/mongo/db/pipeline')
-rw-r--r-- | src/mongo/db/pipeline/SConscript | 1 | ||||
-rw-r--r-- | src/mongo/db/pipeline/window_function/window_function.h | 118 | ||||
-rw-r--r-- | src/mongo/db/pipeline/window_function/window_function_std_dev_test.cpp | 242 |
3 files changed, 361 insertions, 0 deletions
diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript
index 6ac3ef693bf..f276e6a84b8 100644
--- a/src/mongo/db/pipeline/SConscript
+++ b/src/mongo/db/pipeline/SConscript
@@ -433,6 +433,7 @@ env.CppUnitTest(
         'window_function/window_function_exec_non_removable_test.cpp',
         'window_function/window_function_min_max_test.cpp',
         'window_function/window_function_push_test.cpp',
+        'window_function/window_function_std_dev_test.cpp',
     ],
     LIBDEPS=[
         '$BUILD_DIR/mongo/base',
diff --git a/src/mongo/db/pipeline/window_function/window_function.h b/src/mongo/db/pipeline/window_function/window_function.h
index b8feba5a2ef..c2d0979e5b2 100644
--- a/src/mongo/db/pipeline/window_function/window_function.h
+++ b/src/mongo/db/pipeline/window_function/window_function.h
@@ -206,4 +206,122 @@ private:
     // insertion.
     std::list<Value> _list;
 };
+
+/**
+ * Removable state shared by the population and sample standard deviation window functions.
+ *
+ * Tracks the count and sum of the finite numeric values in the window plus '_m2', their sum of
+ * squared deviations from the mean, updating all three incrementally in add() and remove().
+ * Non-numeric values are ignored. Non-finite values (NaN, +/-inf) are only counted; while any
+ * is in the window getValue() returns NaN.
+ */
+class WindowFunctionStdDev : public WindowFunctionState {
+protected:
+    explicit WindowFunctionStdDev(ExpressionContext* const expCtx, bool isSamp)
+        : WindowFunctionState(expCtx),
+          _sum(AccumulatorSum::create(expCtx)),
+          _m2(AccumulatorSum::create(expCtx)),
+          _isSamp(isSamp),
+          _count(0),
+          _nonfiniteValueCount(0) {}
+
+public:
+    static Value getDefault() {
+        return Value(BSONNULL);
+    }
+
+    void add(Value value) {
+        update(std::move(value), +1);
+    }
+
+    void remove(Value value) {
+        update(std::move(value), -1);
+    }
+
+    /**
+     * Returns NaN while a non-finite value is in the window, getDefault() when there are too few
+     * finite values (none for population, fewer than two for sample), and otherwise the standard
+     * deviation as a double.
+     */
+    Value getValue() const final {
+        if (_nonfiniteValueCount > 0)
+            return Value(std::numeric_limits<double>::quiet_NaN());
+        const long long adjustedCount = _isSamp ? _count - 1 : _count;
+        // '<=' and not '==': an empty sample window has an adjusted count of -1 and must also
+        // produce the default rather than divide by a negative number.
+        if (adjustedCount <= 0)
+            return getDefault();
+        // Rounding in the incremental updates can leave M2 marginally below zero when the true
+        // variance is zero; clamp it so that sqrt() does not return NaN.
+        const double m2 = _m2->getValue(false).coerceToDouble();
+        return Value(std::sqrt((m2 < 0.0 ? 0.0 : m2) / adjustedCount));
+    }
+
+    void reset() {
+        resetFiniteState();
+        _nonfiniteValueCount = 0;
+    }
+
+private:
+    // Forgets the finite values (and any rounding error accumulated for them) while leaving the
+    // count of non-finite values alone.
+    void resetFiniteState() {
+        _m2->reset();
+        _sum->reset();
+        _count = 0;
+    }
+
+    void update(Value value, int quantity) {
+        // quantity should be 1 if adding value, -1 if removing value
+        if (!value.numeric())
+            return;
+        if ((value.getType() == NumberDouble && !std::isfinite(value.getDouble())) ||
+            (value.getType() == NumberDecimal && !value.getDecimal().isFinite())) {
+            // Deliberately not reflected in '_count': that count is the 'n' of the incremental
+            // formulas below, which only ever see finite values.
+            _nonfiniteValueCount += quantity;
+            return;
+        }
+
+        const double val = value.coerceToDouble();
+        if (_count == 0) {  // Assuming we are adding value if _count == 0.
+            _count++;
+            _sum->process(Value{val}, false);
+            return;
+        } else if (_count + quantity == 0) {  // Empty the window.
+            // Non-finite values may still be in the window, so only the finite state is reset.
+            resetFiniteState();
+            return;
+        }
+        // x = n * value - sum, i.e. n times the distance between the value and the current mean.
+        const double x = _count * val - _sum->getValue(false).coerceToDouble();
+        _count += quantity;
+        _sum->process(Value{val * quantity}, false);
+        _m2->process(Value{x * x * quantity / (_count * (_count - quantity))}, false);
+    }
+
+    // Std dev cannot make use of RemovableSum because of its specific handling of non-finite
+    // values. Adding a NaN or +/-inf makes the result NaN. Additionally, the consistent and
+    // exclusive use of doubles in std dev calculations makes the type handling in RemovableSum
+    // unnecessary.
+    boost::intrusive_ptr<AccumulatorState> _sum;
+    boost::intrusive_ptr<AccumulatorState> _m2;
+    bool _isSamp;
+    // Number of finite numeric values currently in the window.
+    long long _count;
+    // Number of NaN / +/-inf values currently in the window.
+    long long _nonfiniteValueCount;
+};
+
+class WindowFunctionStdDevPop final : public WindowFunctionStdDev {
+public:
+    explicit WindowFunctionStdDevPop(ExpressionContext* const expCtx)
+        : WindowFunctionStdDev(expCtx, false) {}
+};
+
+class WindowFunctionStdDevSamp final : public WindowFunctionStdDev {
+public:
+    explicit WindowFunctionStdDevSamp(ExpressionContext* const expCtx)
+        : WindowFunctionStdDev(expCtx, true) {}
+};
 }  // namespace mongo
diff --git a/src/mongo/db/pipeline/window_function/window_function_std_dev_test.cpp b/src/mongo/db/pipeline/window_function/window_function_std_dev_test.cpp
new file mode 100644
index 00000000000..42ccb284380
--- /dev/null
+++ b/src/mongo/db/pipeline/window_function/window_function_std_dev_test.cpp
@@ -0,0 +1,242 @@
+/**
+ *    Copyright (C) 2021-present MongoDB, Inc.
+ *
+ *    This program is free software: you can redistribute it and/or modify
+ *    it under the terms of the Server Side Public License, version 1,
+ *    as published by MongoDB, Inc.
+ *
+ *    This program is distributed in the hope that it will be useful,
+ *    but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ *    Server Side Public License for more details.
+ *
+ *    You should have received a copy of the Server Side Public License
+ *    along with this program. If not, see
+ *    <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ *    As a special exception, the copyright holders give permission to link the
+ *    code of portions of this program with the OpenSSL library under certain
+ *    conditions as described in each individual source file and distribute
+ *    linked combinations including the program with the OpenSSL library. You
+ *    must comply with the Server Side Public License in all respects for
+ *    all of the code used other than as permitted herein. If you modify file(s)
+ *    with this exception, you may extend this exception to your version of the
+ *    file(s), but you are not obligated to do so. If you do not wish to do so,
+ *    delete this exception statement from your version. If you delete this
+ *    exception statement from all source files in the program, then also delete
+ *    it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include <cmath>
+#include <limits>
+#include <numeric>
+#include <vector>
+
+#include "mongo/db/exec/document_value/document_value_test_util.h"
+#include "mongo/db/pipeline/window_function/window_function.h"
+#include "mongo/db/query/collation/collator_interface_mock.h"
+#include "mongo/platform/random.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo {
+namespace {
+
+class WindowFunctionStdDevTest : public unittest::Test {
+public:
+    WindowFunctionStdDevTest() : pop(nullptr), samp(nullptr) {}
+
+    WindowFunctionStdDevPop pop;
+    WindowFunctionStdDevSamp samp;
+
+    // Reference implementation: two-pass population standard deviation of [begin, end).
+    double stdDevPop(std::vector<double>::const_iterator begin,
+                     std::vector<double>::const_iterator end) {
+        double mean = std::accumulate(begin, end, 0.0) / (end - begin);
+        double squaredDiffs = std::accumulate(begin, end, 0.0, [&](double acc, double val) {
+            return acc + (val - mean) * (val - mean);
+        });
+        return sqrt(squaredDiffs / (end - begin));
+    }
+};
+
+TEST_F(WindowFunctionStdDevTest, EmptyWindow) {
+    ASSERT_VALUE_EQ(pop.getValue(), Value{BSONNULL});
+}
+
+TEST_F(WindowFunctionStdDevTest, EmptySampleWindow) {
+    ASSERT_VALUE_EQ(samp.getValue(), Value{BSONNULL});
+    samp.add(Value{1});
+    // A single value has no sample standard deviation.
+    ASSERT_VALUE_EQ(samp.getValue(), Value{BSONNULL});
+    samp.remove(Value{1});
+    ASSERT_VALUE_EQ(samp.getValue(), Value{BSONNULL});
+}
+
+TEST_F(WindowFunctionStdDevTest, ReturnsDouble) {
+    pop.add(Value{1});
+    pop.add(Value{2});
+    pop.add(Value{3});
+    ASSERT_EQ(pop.getValue().getType(), NumberDouble);
+
+    samp.add(Value{1});
+    samp.add(Value{2});
+    samp.add(Value{3});
+    // Returns 1.0
+    ASSERT_EQ(samp.getValue().getType(), NumberDouble);
+
+    pop.add(Value{Decimal128("100000000000000000000000000000")});
+    ASSERT_EQ(pop.getValue().getType(), NumberDouble);
+}
+
+
+TEST_F(WindowFunctionStdDevTest, Add) {
+    pop.add(Value{1});
+    pop.add(Value{2});
+    pop.add(Value{3});
+    ASSERT_VALUE_EQ(pop.getValue(), Value{sqrt(2 / 3.0)});
+
+    samp.add(Value{1});
+    samp.add(Value{2});
+    samp.add(Value{3});
+    ASSERT_VALUE_EQ(samp.getValue(), Value{1.0});
+}
+
+TEST_F(WindowFunctionStdDevTest, Remove1) {
+    pop.add(Value{1});
+    pop.add(Value{2});
+    pop.add(Value{3});
+    // Add, then remove
+    pop.add(Value{4});
+    pop.remove(Value{1});
+    ASSERT_VALUE_EQ(pop.getValue(), Value{sqrt(2 / 3.0)});
+}
+
+TEST_F(WindowFunctionStdDevTest, Remove2) {
+    pop.add(Value{1});
+    pop.add(Value{2});
+    pop.add(Value{3});
+    // Remove, then add
+    pop.remove(Value{1});
+    pop.add(Value{4});
+    ASSERT_VALUE_EQ(pop.getValue(), Value{sqrt(2 / 3.0)});
+}
+
+TEST_F(WindowFunctionStdDevTest, SampleRemove) {
+    samp.add(Value{1});
+    samp.add(Value{2});
+    samp.add(Value{3});
+    samp.remove(Value{1});
+    ASSERT_VALUE_EQ(samp.getValue(), Value{sqrt(0.5)});
+}
+
+TEST_F(WindowFunctionStdDevTest, NotDividingByZeroInM2Update) {
+    pop.add(Value{1});
+    pop.remove(Value{1});
+    pop.add(Value{1});
+    pop.add(Value{2});
+    ASSERT_VALUE_EQ(pop.getValue(), Value{0.5});
+
+    double nan = std::numeric_limits<double>::quiet_NaN();
+    samp.add(Value{nan});
+    samp.remove(Value{nan});
+    samp.add(Value{1});
+    samp.add(Value{2});
+    ASSERT_VALUE_EQ(samp.getValue(), Value{sqrt(0.5)});
+}
+
+TEST_F(WindowFunctionStdDevTest, HandlesNonfinite) {
+    double inf = std::numeric_limits<double>::infinity();
+    double nan = std::numeric_limits<double>::quiet_NaN();
+
+    pop.add(Value{1});
+    pop.add(Value{2});
+    pop.add(Value{inf});
+    ASSERT_VALUE_EQ(pop.getValue(), Value{nan});  // 1, 2, inf
+    pop.remove(Value{inf});
+    ASSERT_EQ(pop.getValue().getDouble(), 0.5);  // 1, 2
+    pop.add(Value{nan});
+    ASSERT_VALUE_EQ(pop.getValue(), Value{nan});  // 1, 2, nan
+    pop.remove(Value{nan});
+    pop.add(Value{-inf});
+    ASSERT_VALUE_EQ(pop.getValue(), Value{nan});  // 1, 2, -inf
+    pop.add(Value{inf});
+    ASSERT_VALUE_EQ(pop.getValue(), Value{nan});  // 1, 2, -inf, inf
+    pop.add(Value{nan});
+    ASSERT_VALUE_EQ(pop.getValue(), Value{nan});  // 1, 2, -inf, inf, nan
+}
+
+TEST_F(WindowFunctionStdDevTest, NonfiniteValuesDoNotDisturbFiniteState) {
+    double nan = std::numeric_limits<double>::quiet_NaN();
+
+    // Finite values added while a NaN is in the window must be unaffected by it.
+    pop.add(Value{nan});
+    pop.add(Value{1});
+    pop.add(Value{2});
+    ASSERT_VALUE_EQ(pop.getValue(), Value{nan});
+    pop.remove(Value{nan});
+    ASSERT_VALUE_EQ(pop.getValue(), Value{0.5});
+
+    // The last finite value leaving must not hide a NaN that is still in the window.
+    samp.add(Value{1});
+    samp.add(Value{nan});
+    samp.remove(Value{1});
+    ASSERT_VALUE_EQ(samp.getValue(), Value{nan});
+    samp.remove(Value{nan});
+    ASSERT_VALUE_EQ(samp.getValue(), Value{BSONNULL});
+    samp.add(Value{3});
+    samp.add(Value{4});
+    ASSERT_VALUE_EQ(samp.getValue(), Value{sqrt(0.5)});
+}
+
+TEST_F(WindowFunctionStdDevTest, Stability) {
+    const int collLength = 10000;
+    const int windowSize = 100;
+    PseudoRandom prng(0);
+    std::vector<double> vec(collLength);
+    for (int j = 0; j < collLength; j++) {
+        vec[j] = prng.nextCanonicalDouble() - 0.5;
+    }
+    for (int i = 0; i < windowSize; i++) {
+        pop.add(Value{vec[i]});
+    }
+    for (int i = windowSize; i < collLength; i++) {
+        pop.add(Value{vec[i]});
+        pop.remove(Value{vec[i - windowSize]});
+        double trueStdDev = stdDevPop(vec.begin() + i - windowSize + 1, vec.begin() + i + 1);
+        double calculatedStdDev = pop.getValue().getDouble();
+        ASSERT_LTE(Decimal128(calculatedStdDev).subtract(Decimal128(trueStdDev)).toAbs(),
+                   Decimal128("1e-15"));
+        // NOTE(review): the error is signed, so this only bounds over-estimates; an arbitrarily
+        // large under-estimate would still pass. Consider comparing std::abs(relativeError).
+        double relativeError = (calculatedStdDev - trueStdDev) / trueStdDev;
+        ASSERT_LTE(relativeError, 1e-15);
+    }
+}
+
+TEST_F(WindowFunctionStdDevTest, LargeNumberStability) {
+    const int collLength = 10000;
+    const int windowSize = 100;
+    PseudoRandom prng(0);
+    std::vector<double> vec(collLength);
+    for (int j = 0; j < collLength; j++) {
+        vec[j] = (prng.nextCanonicalDouble() - 0.5) * prng.nextInt64();
+    }
+    for (int i = 0; i < windowSize; i++) {
+        pop.add(Value{vec[i]});
+    }
+    for (int i = windowSize; i < collLength; i++) {
+        pop.add(Value{vec[i]});
+        pop.remove(Value{vec[i - windowSize]});
+        double trueStdDev = stdDevPop(vec.begin() + i - windowSize + 1, vec.begin() + i + 1);
+        double calculatedStdDev = pop.getValue().getDouble();
+        // NOTE(review): the error is signed, so this only bounds over-estimates; an arbitrarily
+        // large under-estimate would still pass. Consider comparing std::abs(relativeError).
+        double relativeError = (calculatedStdDev - trueStdDev) / trueStdDev;
+        ASSERT_LTE(relativeError, 1e-15);
+    }
+}
+
+}  // namespace
+}  // namespace mongo