summaryrefslogtreecommitdiff
path: root/src/mongo/platform/random.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/platform/random.cpp')
-rw-r--r--src/mongo/platform/random.cpp198
1 files changed, 103 insertions, 95 deletions
diff --git a/src/mongo/platform/random.cpp b/src/mongo/platform/random.cpp
index f12ffb57a5d..9ed392adf1c 100644
--- a/src/mongo/platform/random.cpp
+++ b/src/mongo/platform/random.cpp
@@ -39,75 +39,59 @@
#include <bcrypt.h>
#else
#include <errno.h>
+#include <fcntl.h>
#endif
#define _CRT_RAND_S
+#include <array>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <memory>
+#include <random>
#include "mongo/util/assert_util.h"
#include "mongo/util/log.h"
-namespace mongo {
-
-// ---- PseudoRandom -----
+#ifdef _WIN32
+#define SECURE_RANDOM_BCRYPT
+#elif defined(__linux__) || defined(__sun) || defined(__APPLE__) || defined(__FreeBSD__) || \
+ defined(__EMSCRIPTEN__)
+#define SECURE_RANDOM_URANDOM
+#elif defined(__OpenBSD__)
+#define SECURE_RANDOM_ARCFOUR
+#else
+#error "Must implement SecureRandom for platform"
+#endif
-uint32_t PseudoRandom::nextUInt32() {
- uint32_t t = _x ^ (_x << 11);
- _x = _y;
- _y = _z;
- _z = _w;
- return _w = _w ^ (_w >> 19) ^ (t ^ (t >> 8));
-}
+namespace mongo {
namespace {
-const uint32_t default_y = 362436069;
-const uint32_t default_z = 521288629;
-const uint32_t default_w = 88675123;
-} // namespace
-
-PseudoRandom::PseudoRandom(uint32_t seed) {
- _x = seed;
- _y = default_y;
- _z = default_z;
- _w = default_w;
-}
-PseudoRandom::PseudoRandom(int32_t seed) : PseudoRandom(static_cast<uint32_t>(seed)) {}
-
-PseudoRandom::PseudoRandom(int64_t seed)
- : PseudoRandom(static_cast<uint32_t>(seed >> 32) ^ static_cast<uint32_t>(seed)) {}
-
-int32_t PseudoRandom::nextInt32() {
- return nextUInt32();
-}
-
-int64_t PseudoRandom::nextInt64() {
- uint64_t a = nextUInt32();
- uint64_t b = nextUInt32();
- return (a << 32) | b;
-}
-
-double PseudoRandom::nextCanonicalDouble() {
- double result;
- do {
- auto generated = static_cast<uint64_t>(nextInt64());
- result = static_cast<double>(generated) / std::numeric_limits<uint64_t>::max();
- } while (result == 1.0);
- return result;
-}
-
-// --- SecureRandom ----
+template <size_t N>
+struct Buffer {
+ uint64_t pop() {
+ return arr[--avail];
+ }
+ uint8_t* fillPtr() {
+ return reinterpret_cast<uint8_t*>(arr.data() + avail);
+ }
+ size_t fillSize() {
+ return sizeof(uint64_t) * (arr.size() - avail);
+ }
+ void setFilled() {
+ avail = arr.size();
+ }
-SecureRandom::~SecureRandom() {}
+ std::array<uint64_t, N> arr;
+ size_t avail = 0;
+};
-#ifdef _WIN32
-class WinSecureRandom : public SecureRandom {
+#if defined(SECURE_RANDOM_BCRYPT)
+class Source {
public:
- WinSecureRandom() {
+ Source() {
auto ntstatus = ::BCryptOpenAlgorithmProvider(
&_algHandle, BCRYPT_RNG_ALGORITHM, MS_PRIMITIVE_PROVIDER, 0);
if (ntstatus != STATUS_SUCCESS) {
@@ -118,7 +102,7 @@ public:
}
}
- virtual ~WinSecureRandom() {
+ ~Source() {
auto ntstatus = ::BCryptCloseAlgorithmProvider(_algHandle, 0);
if (ntstatus != STATUS_SUCCESS) {
warning() << "Failed to close crypto algorithm provider destroying secure random "
@@ -127,75 +111,99 @@ public:
}
}
- int64_t nextInt64() {
- int64_t value;
- auto ntstatus =
- ::BCryptGenRandom(_algHandle, reinterpret_cast<PUCHAR>(&value), sizeof(value), 0);
+ size_t refill(uint8_t* buf, size_t n) {
+ auto ntstatus = ::BCryptGenRandom(_algHandle, reinterpret_cast<PUCHAR>(buf), n, 0);
if (ntstatus != STATUS_SUCCESS) {
error() << "Failed to generate random number from secure random object; NTSTATUS: "
<< ntstatus;
fassertFailed(28814);
}
- return value;
+ return n;
}
private:
BCRYPT_ALG_HANDLE _algHandle;
};
+#endif // SECURE_RANDOM_BCRYPT
-std::unique_ptr<SecureRandom> SecureRandom::create() {
- return std::make_unique<WinSecureRandom>();
-}
-
-#elif defined(__linux__) || defined(__sun) || defined(__APPLE__) || defined(__FreeBSD__) || \
- defined(__EMSCRIPTEN__)
-
-class InputStreamSecureRandom : public SecureRandom {
+#if defined(SECURE_RANDOM_URANDOM)
+class Source {
public:
- InputStreamSecureRandom(const char* fn) {
- _in = std::make_unique<std::ifstream>(fn, std::ios::binary | std::ios::in);
- if (!_in->is_open()) {
- error() << "cannot open " << fn << " " << strerror(errno);
- fassertFailed(28839);
- }
- }
-
- int64_t nextInt64() {
- int64_t r;
- _in->read(reinterpret_cast<char*>(&r), sizeof(r));
- if (_in->fail()) {
- error() << "InputStreamSecureRandom failed to generate random bytes";
- fassertFailed(28840);
+ size_t refill(uint8_t* buf, size_t n) {
+ size_t i = 0;
+ while (i < n) {
+ ssize_t r;
+ while ((r = read(sharedFd(), buf + i, n - i)) == -1) {
+ if (errno == EINTR) {
+ continue;
+ } else {
+ auto errSave = errno;
+ error() << "SecureRandom: read `" << kFn << "`: " << strerror(errSave);
+ fassertFailed(28840);
+ }
+ }
+ i += r;
}
- return r;
+ return i;
}
private:
- std::unique_ptr<std::ifstream> _in;
+ static constexpr const char* kFn = "/dev/urandom";
+ static int sharedFd() {
+ // Retain the urandom fd forever.
+ // Kernel ensures that concurrent `read` calls don't mingle their data.
+ // http://lkml.iu.edu//hypermail/linux/kernel/0412.1/0181.html
+ static const int fd = [] {
+ int f;
+ while ((f = open(kFn, 0)) == -1) {
+ if (errno == EINTR) {
+ continue;
+ } else {
+ auto errSave = errno;
+ error() << "SecureRandom: open `" << kFn << "`: " << strerror(errSave);
+ fassertFailed(28839);
+ }
+ }
+ return f;
+ }();
+ return fd;
+ }
};
+#endif // SECURE_RANDOM_URANDOM
-std::unique_ptr<SecureRandom> SecureRandom::create() {
- return std::make_unique<InputStreamSecureRandom>("/dev/urandom");
-}
+#if defined(SECURE_RANDOM_ARCFOUR)
+class Source {
+public:
+ size_t refill(uint8_t* buf, size_t n) {
+ arc4random_buf(buf, n);
+ return n;
+ }
+};
+#endif // SECURE_RANDOM_ARCFOUR
-#elif defined(__OpenBSD__)
+} // namespace
-class Arc4SecureRandom : public SecureRandom {
+class SecureUrbg::State {
public:
- int64_t nextInt64() {
- int64_t value;
- arc4random_buf(&value, sizeof(value));
- return value;
+ uint64_t get() {
+ if (!_buffer.avail) {
+ size_t n = _source.refill(_buffer.fillPtr(), _buffer.fillSize());
+ _buffer.avail += n / sizeof(uint64_t);
+ }
+ return _buffer.pop();
}
+
+private:
+ Source _source;
+ Buffer<16> _buffer;
};
-std::unique_ptr<SecureRandom> SecureRandom::create() {
- return std::make_unique<Arc4SecureRandom>();
-}
+SecureUrbg::SecureUrbg() : _state{std::make_unique<State>()} {}
-#else
+SecureUrbg::~SecureUrbg() = default;
-#error Must implement SecureRandom for platform
+uint64_t SecureUrbg::operator()() {
+ return _state->get();
+}
-#endif
} // namespace mongo