summaryrefslogtreecommitdiff
path: root/src/mongo/util/fail_point.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/util/fail_point.cpp')
-rw-r--r--src/mongo/util/fail_point.cpp40
1 files changed, 9 insertions, 31 deletions
diff --git a/src/mongo/util/fail_point.cpp b/src/mongo/util/fail_point.cpp
index 60050da7055..af757e2bd49 100644
--- a/src/mongo/util/fail_point.cpp
+++ b/src/mongo/util/fail_point.cpp
@@ -33,7 +33,10 @@
#include "mongo/util/fail_point.h"
#include <fmt/format.h>
+
+#include <limits>
#include <memory>
+#include <random>
#include "mongo/base/init.h"
#include "mongo/bson/json.h"
@@ -58,38 +61,13 @@ MONGO_INITIALIZER_GENERAL(AllFailPointsRegistered, (), ())
return Status::OK();
}
-/**
- * Type representing the per-thread PRNG used by fail-points.
- */
-class FailPointPRNG {
-public:
- FailPointPRNG() : _prng(std::unique_ptr<SecureRandom>(SecureRandom::create())->nextInt64()) {}
-
- void resetSeed(int32_t seed) {
- _prng = PseudoRandom(seed);
- }
-
- int32_t nextPositiveInt32() {
- return _prng.nextInt32() & ~(1 << 31);
- }
-
- static FailPointPRNG* current() {
- if (!_failPointPrng)
- _failPointPrng = std::make_unique<FailPointPRNG>();
- return _failPointPrng.get();
- }
-
-private:
- PseudoRandom _prng;
- static thread_local std::unique_ptr<FailPointPRNG> _failPointPrng;
-};
-
-thread_local std::unique_ptr<FailPointPRNG> FailPointPRNG::_failPointPrng;
+/** The per-thread PRNG used by fail-points. */
+thread_local PseudoRandom threadPrng{SecureRandom().nextInt64()};
} // namespace
void FailPoint::setThreadPRNGSeed(int32_t seed) {
- FailPointPRNG::current()->resetSeed(seed);
+ threadPrng = PseudoRandom(seed);
}
FailPoint::FailPoint() = default;
@@ -155,10 +133,10 @@ FailPoint::RetCode FailPoint::_slowShouldFailOpenBlock(
case alwaysOn:
return slowOn;
case random: {
- const int maxActivationValue = _timesOrPeriod.load();
- if (FailPointPRNG::current()->nextPositiveInt32() < maxActivationValue)
+ std::uniform_int_distribution<int> distribution{};
+ if (distribution(threadPrng.urbg()) < _timesOrPeriod.load()) {
return slowOn;
-
+ }
return slowOff;
}
case nTimes: {