summaryrefslogtreecommitdiff
path: root/src/mongo/db/pipeline
diff options
context:
space:
mode:
authorHartek Sabharwal <hartek.sabharwal@mongodb.com>2021-03-05 17:23:04 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2021-03-05 21:19:16 +0000
commit8fc03988ff0cdf087b9108aff3a649e59f45cddb (patch)
treeb4893ae8f2533ac5596996b36f05f40643962a17 /src/mongo/db/pipeline
parent69027dee0ae53ef10a0adcc324d06ef12d0f634f (diff)
downloadmongo-8fc03988ff0cdf087b9108aff3a649e59f45cddb.tar.gz
SERVER-54234 Implement removable $stddev window function
Diffstat (limited to 'src/mongo/db/pipeline')
-rw-r--r--src/mongo/db/pipeline/SConscript1
-rw-r--r--src/mongo/db/pipeline/window_function/window_function.h88
-rw-r--r--src/mongo/db/pipeline/window_function/window_function_std_dev_test.cpp201
3 files changed, 290 insertions, 0 deletions
diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript
index 6ac3ef693bf..f276e6a84b8 100644
--- a/src/mongo/db/pipeline/SConscript
+++ b/src/mongo/db/pipeline/SConscript
@@ -433,6 +433,7 @@ env.CppUnitTest(
'window_function/window_function_exec_non_removable_test.cpp',
'window_function/window_function_min_max_test.cpp',
'window_function/window_function_push_test.cpp',
+ 'window_function/window_function_std_dev_test.cpp',
],
LIBDEPS=[
'$BUILD_DIR/mongo/base',
diff --git a/src/mongo/db/pipeline/window_function/window_function.h b/src/mongo/db/pipeline/window_function/window_function.h
index b8feba5a2ef..c2d0979e5b2 100644
--- a/src/mongo/db/pipeline/window_function/window_function.h
+++ b/src/mongo/db/pipeline/window_function/window_function.h
@@ -206,4 +206,92 @@ private:
// insertion.
std::list<Value> _list;
};
+
+/**
+ * Removable window-function state shared by the population and sample standard deviation
+ * functions. It keeps a running sum and a running sum of squared differences from the mean
+ * ('_m2') and updates both incrementally as values enter and leave the window, so getValue()
+ * never has to rescan the window.
+ *
+ * Non-numeric values are ignored. While at least one non-finite value (NaN or +/-infinity) is in
+ * the window the result is NaN; such values are tracked by their own counter and never take part
+ * in the running statistics.
+ */
+class WindowFunctionStdDev : public WindowFunctionState {
+protected:
+    /**
+     * 'isSamp' selects the divisor used by getValue(): the number of finite values minus one when
+     * true (sample), the number of finite values when false (population).
+     */
+    explicit WindowFunctionStdDev(ExpressionContext* const expCtx, bool isSamp)
+        : WindowFunctionState(expCtx),
+          _sum(AccumulatorSum::create(expCtx)),
+          _m2(AccumulatorSum::create(expCtx)),
+          _isSamp(isSamp),
+          _count(0),
+          _nonfiniteValueCount(0) {}
+
+public:
+    /**
+     * The result reported when the window holds too few values for the divisor to be positive.
+     */
+    static Value getDefault() {
+        return Value(BSONNULL);
+    }
+
+    /**
+     * Accounts for 'value' entering the window.
+     */
+    void add(Value value) {
+        update(std::move(value), +1);
+    }
+
+    /**
+     * Accounts for 'value' leaving the window. The caller is expected to remove only values that
+     * were previously added.
+     */
+    void remove(Value value) {
+        update(std::move(value), -1);
+    }
+
+    /**
+     * Returns NaN if any non-finite value is in the window, null if the divisor would not be
+     * positive (empty window, or a single value for the sample variant), and otherwise the
+     * standard deviation as a double.
+     */
+    Value getValue() const final {
+        if (_nonfiniteValueCount > 0)
+            return Value(std::numeric_limits<double>::quiet_NaN());
+        const long long adjustedCount = _isSamp ? _count - 1 : _count;
+        // '<=' rather than '==': for the sample variant an empty window gives -1, and
+        // sqrt(0 / -1) would yield -0.0 instead of the default.
+        if (adjustedCount <= 0)
+            return getDefault();
+        const double squaredDifferences = _m2->getValue(false).coerceToDouble();
+        // A sum of squared differences cannot be negative, but repeated add/remove updates can
+        // leave a tiny negative residue from rounding. Taking its square root would produce NaN,
+        // so report the nearest valid result instead.
+        if (squaredDifferences < 0)
+            return Value(0.0);
+        return Value(sqrt(squaredDifferences / adjustedCount));
+    }
+
+    /**
+     * Returns the state to that of a freshly constructed, empty window.
+     */
+    void reset() {
+        _m2->reset();
+        _sum->reset();
+        _count = 0;
+        _nonfiniteValueCount = 0;
+    }
+
+private:
+    /**
+     * Applies one value to the running statistics. 'quantity' must be +1 when the value is being
+     * added and -1 when it is being removed.
+     */
+    void update(Value value, int quantity) {
+        if (!value.numeric())
+            return;
+        if ((value.getType() == NumberDouble && !std::isfinite(value.getDouble())) ||
+            (value.getType() == NumberDecimal && !value.getDecimal().isFinite())) {
+            // Only the dedicated counter changes. '_count' is the 'n' of the incremental formulas
+            // below, so letting non-finite values into it would corrupt '_m2' for the finite
+            // values that share the window with them.
+            _nonfiniteValueCount += quantity;
+            return;
+        }
+
+        if (_count == 0) {  // Assuming we are adding value if _count == 0.
+            _count++;
+            _sum->process(value, false);
+            return;
+        } else if (_count + quantity == 0) {
+            // The last finite value is leaving. Clear the running statistics so accumulated
+            // rounding error is not carried over, but keep '_nonfiniteValueCount': non-finite
+            // values may still be in the window.
+            _m2->reset();
+            _sum->reset();
+            _count = 0;
+            return;
+        }
+
+        // With n = old count and S = old sum, x = n * value - S. The change in the sum of squared
+        // differences is then +/- x^2 / (n_old * n_new). Neither factor of the divisor can be
+        // zero thanks to the two early returns above.
+        const double x = _count * value.coerceToDouble() - _sum->getValue(false).coerceToDouble();
+        _count += quantity;
+        _sum->process(Value{value.coerceToDouble() * quantity}, false);
+        _m2->process(Value{x * x * quantity / (_count * (_count - quantity))}, false);
+    }
+
+    // Std dev cannot make use of RemovableSum because of its specific handling of non-finite
+    // values. Adding a NaN or +/-inf makes the result NaN. Additionally, the consistent and
+    // exclusive use of doubles in std dev calculations makes the type handling in RemovableSum
+    // unnecessary.
+    boost::intrusive_ptr<AccumulatorState> _sum;
+    boost::intrusive_ptr<AccumulatorState> _m2;
+    bool _isSamp;
+    // Number of finite numeric values currently in the window.
+    long long _count;
+    // Number of NaN / +/-infinity values currently in the window.
+    long long _nonfiniteValueCount;
+};
+
+/**
+ * Population flavour of the removable standard deviation: constructs the shared state with
+ * 'isSamp' set to false.
+ */
+class WindowFunctionStdDevPop final : public WindowFunctionStdDev {
+public:
+    explicit WindowFunctionStdDevPop(ExpressionContext* const expCtx)
+        : WindowFunctionStdDev{expCtx, /*isSamp=*/false} {}
+};
+
+/**
+ * Sample flavour of the removable standard deviation: constructs the shared state with 'isSamp'
+ * set to true.
+ */
+class WindowFunctionStdDevSamp final : public WindowFunctionStdDev {
+public:
+    explicit WindowFunctionStdDevSamp(ExpressionContext* const expCtx)
+        : WindowFunctionStdDev{expCtx, /*isSamp=*/true} {}
+};
} // namespace mongo
diff --git a/src/mongo/db/pipeline/window_function/window_function_std_dev_test.cpp b/src/mongo/db/pipeline/window_function/window_function_std_dev_test.cpp
new file mode 100644
index 00000000000..42ccb284380
--- /dev/null
+++ b/src/mongo/db/pipeline/window_function/window_function_std_dev_test.cpp
@@ -0,0 +1,201 @@
+/**
+ * Copyright (C) 2021-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/db/exec/document_value/document_value_test_util.h"
+#include "mongo/db/pipeline/window_function/window_function.h"
+#include "mongo/db/query/collation/collator_interface_mock.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo {
+namespace {
+
+/**
+ * Fixture providing one population and one sample std dev state (both built with a null
+ * expression context) plus a straightforward reference implementation to compare against.
+ */
+class WindowFunctionStdDevTest : public unittest::Test {
+public:
+    WindowFunctionStdDevTest() : pop(nullptr), samp(nullptr) {}
+
+    WindowFunctionStdDevPop pop;
+    WindowFunctionStdDevSamp samp;
+
+    /**
+     * Reference population standard deviation over [begin, end), computed in two passes: first
+     * the mean, then the sum of squared differences from it.
+     */
+    double stdDevPop(std::vector<double>::const_iterator begin,
+                     std::vector<double>::const_iterator end) {
+        const auto length = end - begin;
+
+        double total = 0.0;
+        for (auto it = begin; it != end; ++it) {
+            total = total + *it;
+        }
+        const double mean = total / length;
+
+        double squaredDiffs = 0.0;
+        for (auto it = begin; it != end; ++it) {
+            squaredDiffs = squaredDiffs + (*it - mean) * (*it - mean);
+        }
+        return sqrt(squaredDiffs / length);
+    }
+};
+
+// A population window that never received a value reports null.
+TEST_F(WindowFunctionStdDevTest, EmptyWindow) {
+    const Value expected{BSONNULL};
+    ASSERT_VALUE_EQ(pop.getValue(), expected);
+}
+
+// The result type is double regardless of the input types (ints, then a Decimal128).
+TEST_F(WindowFunctionStdDevTest, ReturnsDouble) {
+    const int inputs[] = {1, 2, 3};
+
+    for (int input : inputs) {
+        pop.add(Value{input});
+    }
+    ASSERT_EQ(pop.getValue().getType(), NumberDouble);
+
+    for (int input : inputs) {
+        samp.add(Value{input});
+    }
+    // The sample std dev of 1, 2, 3 is exactly 1.0, yet it is still reported as a double.
+    ASSERT_EQ(samp.getValue().getType(), NumberDouble);
+
+    const Decimal128 hugeDecimal("100000000000000000000000000000");
+    pop.add(Value{hugeDecimal});
+    ASSERT_EQ(pop.getValue().getType(), NumberDouble);
+}
+
+
+// Adding 1, 2, 3 gives sqrt(2/3) for the population variant and 1.0 for the sample variant.
+TEST_F(WindowFunctionStdDevTest, Add) {
+    const int inputs[] = {1, 2, 3};
+
+    for (int input : inputs) {
+        pop.add(Value{input});
+    }
+    ASSERT_VALUE_EQ(pop.getValue(), Value{sqrt(2 / 3.0)});
+
+    for (int input : inputs) {
+        samp.add(Value{input});
+    }
+    ASSERT_VALUE_EQ(samp.getValue(), Value{1.0});
+}
+
+// Sliding the window from {1, 2, 3} to {2, 3, 4} by adding first and removing second.
+TEST_F(WindowFunctionStdDevTest, Remove1) {
+    const int initialWindow[] = {1, 2, 3};
+    for (int input : initialWindow) {
+        pop.add(Value{input});
+    }
+
+    pop.add(Value{4});
+    pop.remove(Value{1});
+
+    const Value expected{sqrt(2 / 3.0)};
+    ASSERT_VALUE_EQ(pop.getValue(), expected);
+}
+
+// Sliding the window from {1, 2, 3} to {2, 3, 4} by removing first and adding second.
+TEST_F(WindowFunctionStdDevTest, Remove2) {
+    const int initialWindow[] = {1, 2, 3};
+    for (int input : initialWindow) {
+        pop.add(Value{input});
+    }
+
+    pop.remove(Value{1});
+    pop.add(Value{4});
+
+    const Value expected{sqrt(2 / 3.0)};
+    ASSERT_VALUE_EQ(pop.getValue(), expected);
+}
+
+// After removing 1 from {1, 2, 3} the sample std dev of {2, 3} is sqrt(0.5).
+TEST_F(WindowFunctionStdDevTest, SampleRemove) {
+    const int initialWindow[] = {1, 2, 3};
+    for (int input : initialWindow) {
+        samp.add(Value{input});
+    }
+
+    samp.remove(Value{1});
+
+    const Value expected{sqrt(0.5)};
+    ASSERT_VALUE_EQ(samp.getValue(), expected);
+}
+
+// Emptying the window and refilling it must not leave the state in a condition where the next
+// update divides by zero; the results after refilling must be the ordinary ones.
+TEST_F(WindowFunctionStdDevTest, NotDividingByZeroInM2Update) {
+    // Window emptied by removing a finite value.
+    const Value one{1};
+    pop.add(one);
+    pop.remove(one);
+    pop.add(Value{1});
+    pop.add(Value{2});
+    ASSERT_VALUE_EQ(pop.getValue(), Value{0.5});
+
+    // Window emptied by removing a NaN.
+    const double notANumber = std::numeric_limits<double>::quiet_NaN();
+    samp.add(Value{notANumber});
+    samp.remove(Value{notANumber});
+    samp.add(Value{1});
+    samp.add(Value{2});
+    ASSERT_VALUE_EQ(samp.getValue(), Value{sqrt(0.5)});
+}
+
+// Any NaN or +/-infinity in the window turns the result into NaN; removing it restores the
+// result for the remaining finite values.
+TEST_F(WindowFunctionStdDevTest, HandlesNonfinite) {
+    const double infinity = std::numeric_limits<double>::infinity();
+    const double notANumber = std::numeric_limits<double>::quiet_NaN();
+    const Value nanResult{notANumber};
+
+    pop.add(Value{1});
+    pop.add(Value{2});
+
+    // Window: 1, 2, inf
+    pop.add(Value{infinity});
+    ASSERT_VALUE_EQ(pop.getValue(), nanResult);
+
+    // Window: 1, 2
+    pop.remove(Value{infinity});
+    ASSERT_EQ(pop.getValue().getDouble(), 0.5);
+
+    // Window: 1, 2, nan
+    pop.add(Value{notANumber});
+    ASSERT_VALUE_EQ(pop.getValue(), nanResult);
+
+    // Window: 1, 2, -inf
+    pop.remove(Value{notANumber});
+    pop.add(Value{-infinity});
+    ASSERT_VALUE_EQ(pop.getValue(), nanResult);
+
+    // Window: 1, 2, -inf, inf
+    pop.add(Value{infinity});
+    ASSERT_VALUE_EQ(pop.getValue(), nanResult);
+
+    // Window: 1, 2, -inf, inf, nan
+    pop.add(Value{notANumber});
+    ASSERT_VALUE_EQ(pop.getValue(), nanResult);
+}
+
+// Slides a 100-element window across 10000 pseudo-random values in [-0.5, 0.5) and compares the
+// incrementally maintained result with a two-pass recomputation at every step.
+TEST_F(WindowFunctionStdDevTest, Stability) {
+    const int collLength = 10000;
+    const int windowSize = 100;
+    PseudoRandom prng(0);
+    std::vector<double> vec(collLength);
+    for (double& sample : vec) {
+        sample = prng.nextCanonicalDouble() - 0.5;
+    }
+
+    // Fill the initial window.
+    int next = 0;
+    for (; next < windowSize; next++) {
+        pop.add(Value{vec[next]});
+    }
+
+    // Slide: one value in, the oldest value out, then compare.
+    for (; next < collLength; next++) {
+        pop.add(Value{vec[next]});
+        pop.remove(Value{vec[next - windowSize]});
+
+        const double expected = stdDevPop(vec.begin() + next - windowSize + 1,  //
+                                          vec.begin() + next + 1);
+        const double actual = pop.getValue().getDouble();
+
+        ASSERT_LTE(Decimal128(actual).subtract(Decimal128(expected)).toAbs(),
+                   Decimal128("1e-15"));
+
+        // NOTE(review): this relative error is signed, so the assertion only bounds
+        // overestimates; an underestimate of any size passes. Consider comparing the absolute
+        // value once the achievable tolerance has been confirmed.
+        const double relativeError = (actual - expected) / expected;
+        ASSERT_LTE(relativeError, 1e-15);
+    }
+}
+
+// Same sliding-window comparison as the Stability test, but with each value scaled by a
+// pseudo-random 64-bit integer so that magnitudes vary widely. Only the relative error is checked.
+TEST_F(WindowFunctionStdDevTest, LargeNumberStability) {
+    const int collLength = 10000;
+    const int windowSize = 100;
+    PseudoRandom prng(0);
+    std::vector<double> vec(collLength);
+    for (double& sample : vec) {
+        sample = (prng.nextCanonicalDouble() - 0.5) * prng.nextInt64();
+    }
+
+    // Fill the initial window.
+    int next = 0;
+    for (; next < windowSize; next++) {
+        pop.add(Value{vec[next]});
+    }
+
+    // Slide: one value in, the oldest value out, then compare.
+    for (; next < collLength; next++) {
+        pop.add(Value{vec[next]});
+        pop.remove(Value{vec[next - windowSize]});
+
+        const double expected = stdDevPop(vec.begin() + next - windowSize + 1,  //
+                                          vec.begin() + next + 1);
+        const double actual = pop.getValue().getDouble();
+
+        // NOTE(review): this relative error is signed, so the assertion only bounds
+        // overestimates; an underestimate of any size passes. Consider comparing the absolute
+        // value once the achievable tolerance has been confirmed.
+        const double relativeError = (actual - expected) / expected;
+        ASSERT_LTE(relativeError, 1e-15);
+    }
+}
+
+} // namespace
+} // namespace mongo