diff options
author | Hartek Sabharwal <hartek.sabharwal@mongodb.com> | 2021-03-15 19:02:14 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2021-03-15 21:18:59 +0000 |
commit | ab6cfc3ad758c41d5e8bb9364ef8cb3639e84f0f (patch) | |
tree | 00a486d4efd451b0f4d62e27f0621015363bb419 | |
parent | c229494bfe372c2578b0a69f8ab1e0733fcffed0 (diff) | |
download | mongo-ab6cfc3ad758c41d5e8bb9364ef8cb3639e84f0f.tar.gz |
SERVER-55063 Window function StdDev can take the sqrt of a negative number and return NaN
3 files changed, 60 insertions, 1 deletion
diff --git a/src/mongo/db/pipeline/accumulator_std_dev.cpp b/src/mongo/db/pipeline/accumulator_std_dev.cpp index dc696e82ea5..e2208d92ff0 100644 --- a/src/mongo/db/pipeline/accumulator_std_dev.cpp +++ b/src/mongo/db/pipeline/accumulator_std_dev.cpp @@ -62,7 +62,7 @@ void AccumulatorStdDev::processInternal(const Value& input, bool merging) { const double val = input.getDouble(); // This is an implementation of the following algorithm: - // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm _count += 1; const double delta = val - _mean; if (delta != 0.0) { diff --git a/src/mongo/db/pipeline/window_function/window_function_std_dev_test.cpp b/src/mongo/db/pipeline/window_function/window_function_std_dev_test.cpp index dbb270a4387..efae0ecfb58 100644 --- a/src/mongo/db/pipeline/window_function/window_function_std_dev_test.cpp +++ b/src/mongo/db/pipeline/window_function/window_function_std_dev_test.cpp @@ -60,6 +60,13 @@ TEST_F(WindowFunctionStdDevTest, EmptyWindow) { ASSERT_VALUE_EQ(pop.getValue(), Value{BSONNULL}); } +TEST_F(WindowFunctionStdDevTest, SingletonWindow) { + pop.add(Value{1}); + ASSERT_VALUE_EQ(pop.getValue(), Value{0}); + samp.add(Value{1}); + ASSERT_VALUE_EQ(samp.getValue(), Value{BSONNULL}); +} + TEST_F(WindowFunctionStdDevTest, ReturnsDouble) { pop.add(Value{1}); pop.add(Value{2}); @@ -197,5 +204,47 @@ TEST_F(WindowFunctionStdDevTest, LargeNumberStability) { } } +TEST_F(WindowFunctionStdDevTest, HandlesUnderflow) { + double nan = std::numeric_limits<double>::quiet_NaN(); + const int collLength = 10000; + const int windowSize = 100; + PseudoRandom prng(0); + std::vector<double> vec(collLength); + for (int j = 0; j < collLength; j++) { + vec[j] = prng.nextCanonicalDouble() - 0.5; + } + for (int i = 0; i < collLength / windowSize; i++) { + // Fill up the window. Remove all but one element. 
The population std dev should now equal + // exactly 0 since there is only one element, but due to floating point error, the _m2 + // quantity might be a small negative value. Taking the sqrt of this in the std dev formula + // would then return NaN. + for (int j = 0; j < windowSize; j++) + pop.add(Value{vec[i * windowSize + j]}); + for (int k = 0; k < windowSize - 1; k++) + pop.remove(Value{vec[i * windowSize + k]}); + // NaN and -NaN are treated as equal when wrapped in a Value. + ASSERT_VALUE_NE(pop.getValue(), Value{nan}); + ASSERT_VALUE_EQ(pop.getValue(), Value{0}); + // Empty the window. + pop.remove(Value{vec[i * windowSize + (windowSize - 1)]}); + ASSERT_VALUE_EQ(pop.getValue(), Value{BSONNULL}); + } +} + +TEST_F(WindowFunctionStdDevTest, ConstantInput) { + const int collLength = 1000; + const int windowSize = 100; + PseudoRandom prng(0); + const double constant = prng.nextCanonicalDouble() - 0.5; + for (int i = 0; i < windowSize; i++) { + pop.add(Value{constant}); + } + for (int i = windowSize; i < collLength; i++) { + pop.add(Value{constant}); + pop.remove(Value{constant}); + ASSERT_VALUE_EQ(pop.getValue(), Value{0}); + } +} + } // namespace } // namespace mongo diff --git a/src/mongo/db/pipeline/window_function/window_function_stddev.h b/src/mongo/db/pipeline/window_function/window_function_stddev.h index 5b7b5f77839..bd2eed4a544 100644 --- a/src/mongo/db/pipeline/window_function/window_function_stddev.h +++ b/src/mongo/db/pipeline/window_function/window_function_stddev.h @@ -62,6 +62,16 @@ public: const long long adjustedCount = _isSamp ? _count - 1 : _count; if (adjustedCount == 0) return getDefault(); + double squaredDifferences = _m2->getValue(false).coerceToDouble(); + if (squaredDifferences < 0 || (!_isSamp && _count == 1)) { + // _m2 is the sum of squared differences from the mean, so it should always be + // nonnegative. It may take on a small negative value due to floating point error, which + // breaks the sqrt calculation. 
In this case, the closest valid value for _m2 is 0, so + // we reset _m2 and return 0 for the standard deviation. + // If we're doing a population std dev of one element, it is also correct to return 0. + _m2->reset(); + return Value{0}; + } return Value(sqrt(_m2->getValue(false).coerceToDouble() / adjustedCount)); } |