summary refs log tree commit diff
diff options
context:
space:
mode:
author Hartek Sabharwal <hartek.sabharwal@mongodb.com> 2021-03-15 19:02:14 +0000
committer Evergreen Agent <no-reply@evergreen.mongodb.com> 2021-03-15 21:18:59 +0000
commit ab6cfc3ad758c41d5e8bb9364ef8cb3639e84f0f (patch)
tree 00a486d4efd451b0f4d62e27f0621015363bb419
parent c229494bfe372c2578b0a69f8ab1e0733fcffed0 (diff)
download mongo-ab6cfc3ad758c41d5e8bb9364ef8cb3639e84f0f.tar.gz
SERVER-55063 Window function StdDev can take the sqrt of a negative number and return NaN
-rw-r--r-- src/mongo/db/pipeline/accumulator_std_dev.cpp 2
-rw-r--r-- src/mongo/db/pipeline/window_function/window_function_std_dev_test.cpp 49
-rw-r--r-- src/mongo/db/pipeline/window_function/window_function_stddev.h 10
3 files changed, 60 insertions, 1 deletions
diff --git a/src/mongo/db/pipeline/accumulator_std_dev.cpp b/src/mongo/db/pipeline/accumulator_std_dev.cpp
index dc696e82ea5..e2208d92ff0 100644
--- a/src/mongo/db/pipeline/accumulator_std_dev.cpp
+++ b/src/mongo/db/pipeline/accumulator_std_dev.cpp
@@ -62,7 +62,7 @@ void AccumulatorStdDev::processInternal(const Value& input, bool merging) {
const double val = input.getDouble();
// This is an implementation of the following algorithm:
- // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm
+ // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
_count += 1;
const double delta = val - _mean;
if (delta != 0.0) {
diff --git a/src/mongo/db/pipeline/window_function/window_function_std_dev_test.cpp b/src/mongo/db/pipeline/window_function/window_function_std_dev_test.cpp
index dbb270a4387..efae0ecfb58 100644
--- a/src/mongo/db/pipeline/window_function/window_function_std_dev_test.cpp
+++ b/src/mongo/db/pipeline/window_function/window_function_std_dev_test.cpp
@@ -60,6 +60,13 @@ TEST_F(WindowFunctionStdDevTest, EmptyWindow) {
ASSERT_VALUE_EQ(pop.getValue(), Value{BSONNULL});
}
+TEST_F(WindowFunctionStdDevTest, SingletonWindow) {
+ pop.add(Value{1});
+ ASSERT_VALUE_EQ(pop.getValue(), Value{0});
+ samp.add(Value{1});
+ ASSERT_VALUE_EQ(samp.getValue(), Value{BSONNULL});
+}
+
TEST_F(WindowFunctionStdDevTest, ReturnsDouble) {
pop.add(Value{1});
pop.add(Value{2});
@@ -197,5 +204,47 @@ TEST_F(WindowFunctionStdDevTest, LargeNumberStability) {
}
}
+TEST_F(WindowFunctionStdDevTest, HandlesUnderflow) {
+ double nan = std::numeric_limits<double>::quiet_NaN();
+ const int collLength = 10000;
+ const int windowSize = 100;
+ PseudoRandom prng(0);
+ std::vector<double> vec(collLength);
+ for (int j = 0; j < collLength; j++) {
+ vec[j] = prng.nextCanonicalDouble() - 0.5;
+ }
+ for (int i = 0; i < collLength / windowSize; i++) {
+ // Fill up the window. Remove all but one element. The population std dev should now equal
+ // exactly 0 since there is only one element, but due to floating point error, the _m2
+ // quantity might be a small negative value. Taking the sqrt of this in the std dev formula
+ // would then return NaN.
+ for (int j = 0; j < windowSize; j++)
+ pop.add(Value{vec[i * windowSize + j]});
+ for (int k = 0; k < windowSize - 1; k++)
+ pop.remove(Value{vec[i * windowSize + k]});
+ // NaN and -NaN are treated as equal when wrapped in a Value.
+ ASSERT_VALUE_NE(pop.getValue(), Value{nan});
+ ASSERT_VALUE_EQ(pop.getValue(), Value{0});
+ // Empty the window.
+ pop.remove(Value{vec[i * windowSize + (windowSize - 1)]});
+ ASSERT_VALUE_EQ(pop.getValue(), Value{BSONNULL});
+ }
+}
+
+TEST_F(WindowFunctionStdDevTest, ConstantInput) {
+ const int collLength = 1000;
+ const int windowSize = 100;
+ PseudoRandom prng(0);
+ const double constant = prng.nextCanonicalDouble() - 0.5;
+ for (int i = 0; i < windowSize; i++) {
+ pop.add(Value{constant});
+ }
+ for (int i = windowSize; i < collLength; i++) {
+ pop.add(Value{constant});
+ pop.remove(Value{constant});
+ ASSERT_VALUE_EQ(pop.getValue(), Value{0});
+ }
+}
+
} // namespace
} // namespace mongo
diff --git a/src/mongo/db/pipeline/window_function/window_function_stddev.h b/src/mongo/db/pipeline/window_function/window_function_stddev.h
index 5b7b5f77839..bd2eed4a544 100644
--- a/src/mongo/db/pipeline/window_function/window_function_stddev.h
+++ b/src/mongo/db/pipeline/window_function/window_function_stddev.h
@@ -62,6 +62,16 @@ public:
const long long adjustedCount = _isSamp ? _count - 1 : _count;
if (adjustedCount == 0)
return getDefault();
+ double squaredDifferences = _m2->getValue(false).coerceToDouble();
+ if (squaredDifferences < 0 || (!_isSamp && _count == 1)) {
+ // _m2 is the sum of squared differences from the mean, so it should always be
+ // nonnegative. It may take on a small negative value due to floating point error, which
+ // breaks the sqrt calculation. In this case, the closest valid value for _m2 is 0, so
+ // we reset _m2 and return 0 for the standard deviation.
+ // If we're doing a population std dev of one element, it is also correct to return 0.
+ _m2->reset();
+ return Value{0};
+ }
return Value(sqrt(_m2->getValue(false).coerceToDouble() / adjustedCount));
}